From 1fdd99a9b609a5e906acce414226ff44d75d5975 Mon Sep 17 00:00:00 2001 From: Virag Tripathi <15679776+viragtripathi@users.noreply.github.com> Date: Wed, 11 Feb 2026 17:05:03 -0500 Subject: [PATCH 1/9] feat(cockroachdb): add CockroachDB integration with cockroach-go (#2006) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add support for CockroachDB v25.4.0+ using the official cockroach-go/v2 library for automatic transaction retry. - Add CockroachDB source with ExecuteTxWithRetry using crdbpgx.ExecuteTx - Implement 4 tools: execute-sql, sql, list-tables, list-schemas - Use UUID primary keys (CockroachDB best practice) - Add unit tests for source and all tools - Add integration tests with retry verification - Update Cloud Build configuration for CI Fixes #2005 ## Description > Should include a concise description of the changes (bug or feature), it's > impact, along with a summary of the solution ## PR Checklist > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [ ] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involve a breaking change 🛠️ Fixes # --------- Co-authored-by: duwenxin99 Co-authored-by: Wenxin Du <117315983+duwenxin99@users.noreply.github.com> --- .ci/integration.cloudbuild.yaml | 32 ++ cmd/root.go | 5 + docs/en/resources/sources/cockroachdb.md | 242 ++++++++++ .../cockroachdb/cockroachdb-execute-sql.md | 273 +++++++++++ .../cockroachdb/cockroachdb-list-schemas.md | 305 ++++++++++++ .../cockroachdb/cockroachdb-list-tables.md | 344 +++++++++++++ .../tools/cockroachdb/cockroachdb-sql.md | 291 +++++++++++ go.mod | 1 + go.sum | 4 + internal/sources/cockroachdb/cockroachdb.go | 430 +++++++++++++++++ .../sources/cockroachdb/cockroachdb_test.go | 224 +++++++++ internal/sources/cockroachdb/security_test.go | 455 ++++++++++++++++++ .../cockroachdbexecutesql.go | 186 +++++++ .../cockroachdbexecutesql_test.go | 81 ++++ .../cockroachdblistschemas.go | 187 +++++++ .../cockroachdblistschemas_test.go | 81 ++++ .../cockroachdblisttables.go | 261 ++++++++++ .../cockroachdblisttables_test.go | 81 ++++ .../cockroachdbsql/cockroachdbsql.go | 192 ++++++++ .../cockroachdbsql/cockroachdbsql_test.go | 93 ++++ .../cockroachdb_integration_test.go | 220 +++++++++ tests/common.go | 35 ++ 22 files changed, 4023 insertions(+) create mode 100644 docs/en/resources/sources/cockroachdb.md create mode 100644 docs/en/resources/tools/cockroachdb/cockroachdb-execute-sql.md create mode 100644 docs/en/resources/tools/cockroachdb/cockroachdb-list-schemas.md create mode 100644 docs/en/resources/tools/cockroachdb/cockroachdb-list-tables.md create mode 100644 docs/en/resources/tools/cockroachdb/cockroachdb-sql.md create mode 100644 internal/sources/cockroachdb/cockroachdb.go create mode 100644 internal/sources/cockroachdb/cockroachdb_test.go create mode 100644 internal/sources/cockroachdb/security_test.go create mode 100644 internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go create mode 100644 internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql_test.go create mode 100644 internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go create mode 100644 internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas_test.go create mode 100644 internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go create mode 100644 internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables_test.go create mode 100644 internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go create mode 100644 internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql_test.go create mode 100644 tests/cockroachdb/cockroachdb_integration_test.go diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index bfa16e7156..96bbbd8b9b 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -354,6 +354,30 @@ steps: postgressql \ postgresexecutesql + - id: "cockroachdb" + name: golang:1 + waitFor: ["compile-test-binary"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "COCKROACHDB_DATABASE=$_DATABASE_NAME" + - "COCKROACHDB_PORT=$_COCKROACHDB_PORT" + - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" + secretEnv: ["COCKROACHDB_USER", "COCKROACHDB_HOST","CLIENT_ID"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + .ci/test_with_coverage.sh \ + "CockroachDB" \ + cockroachdb \ + cockroachdbsql \ + cockroachdbexecutesql \ + cockroachdblisttables \ + cockroachdblistschemas + - id: "spanner" name: golang:1 waitFor: ["compile-test-binary"] @@ -1129,6 +1153,11 @@ availableSecrets: env: MARIADB_HOST - versionName: projects/$PROJECT_ID/secrets/mongodb_uri/versions/latest env: MONGODB_URI + - versionName: projects/$PROJECT_ID/secrets/cockroachdb_user/versions/latest + env: COCKROACHDB_USER + - versionName: projects/$PROJECT_ID/secrets/cockroachdb_host/versions/latest + env: COCKROACHDB_HOST + options: logging: CLOUD_LOGGING_ONLY @@ -1189,6 +1218,9 @@ substitutions: _SINGLESTORE_PORT: "3308" _SINGLESTORE_DATABASE: "singlestore" _SINGLESTORE_USER: "root" + _COCKROACHDB_HOST: 127.0.0.1 + _COCKROACHDB_PORT: "26257" + _COCKROACHDB_USER: "root" _MARIADB_PORT: "3307" _MARIADB_DATABASE: test_database _SNOWFLAKE_DATABASE: "test" diff --git a/cmd/root.go b/cmd/root.go index 5e59997211..3d62c11dc8 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -110,6 +110,10 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck" + _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblistschemas" + _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbsql" _ "github.com/googleapis/genai-toolbox/internal/tools/couchbase" _ "github.com/googleapis/genai-toolbox/internal/tools/dataform/dataformcompilelocal" _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupentry" @@ -256,6 +260,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" + _ "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" _ "github.com/googleapis/genai-toolbox/internal/sources/couchbase" _ "github.com/googleapis/genai-toolbox/internal/sources/dataplex" _ "github.com/googleapis/genai-toolbox/internal/sources/dgraph" diff --git a/docs/en/resources/sources/cockroachdb.md b/docs/en/resources/sources/cockroachdb.md new file mode 100644 index 0000000000..9ecf884ce5 --- /dev/null +++ b/docs/en/resources/sources/cockroachdb.md @@ -0,0 +1,242 @@ +--- +title: "CockroachDB" +type: docs +weight: 1 +description: > + CockroachDB is a distributed SQL database built for cloud applications. + +--- + +## About + +[CockroachDB][crdb-docs] is a distributed SQL database designed for cloud-native applications. It provides strong consistency, horizontal scalability, and built-in resilience with automatic failover and recovery. CockroachDB uses the PostgreSQL wire protocol, making it compatible with many PostgreSQL tools and drivers while providing unique features like multi-region deployments and distributed transactions. + +**Minimum Version:** CockroachDB v25.1 or later is recommended for full tool compatibility. + +[crdb-docs]: https://www.cockroachlabs.com/docs/ + +## Available Tools + +- [`cockroachdb-sql`](../tools/cockroachdb/cockroachdb-sql.md) + Execute SQL queries as prepared statements in CockroachDB (alias for execute-sql). + +- [`cockroachdb-execute-sql`](../tools/cockroachdb/cockroachdb-execute-sql.md) + Run parameterized SQL statements in CockroachDB. + +- [`cockroachdb-list-schemas`](../tools/cockroachdb/cockroachdb-list-schemas.md) + List schemas in a CockroachDB database. + +- [`cockroachdb-list-tables`](../tools/cockroachdb/cockroachdb-list-tables.md) + List tables in a CockroachDB database. + +## Requirements + +### Database User + +This source uses standard authentication. You will need to [create a CockroachDB user][crdb-users] to login to the database with. For CockroachDB Cloud deployments, SSL/TLS is required. + +[crdb-users]: https://www.cockroachlabs.com/docs/stable/create-user.html + +### SSL/TLS Configuration + +CockroachDB Cloud clusters require SSL/TLS connections. Use the `queryParams` section to configure SSL settings: + +- **For CockroachDB Cloud**: Use `sslmode: require` at minimum +- **For self-hosted with certificates**: Use `sslmode: verify-full` with certificate paths +- **For local development only**: Use `sslmode: disable` (not recommended for production) + +## Example + +```yaml +sources: + my_cockroachdb: + type: cockroachdb + host: your-cluster.cockroachlabs.cloud + port: "26257" + user: myuser + password: mypassword + database: defaultdb + maxRetries: 5 + retryBaseDelay: 500ms + queryParams: + sslmode: require + application_name: my-app + + # MCP Security Settings (recommended for production) + readOnlyMode: true # Read-only by default (MCP best practice) + enableWriteMode: false # Set to true to allow write operations + maxRowLimit: 1000 # Limit query results + queryTimeoutSec: 30 # Prevent long-running queries + enableTelemetry: true # Enable observability + telemetryVerbose: false # Set true for detailed logs + clusterID: "my-cluster" # Optional identifier + +tools: + list_expenses: + type: cockroachdb-sql + source: my_cockroachdb + description: List all expenses + statement: SELECT id, description, amount, category FROM expenses WHERE user_id = $1 + parameters: + - name: user_id + type: string + description: The user's ID + + describe_expenses: + type: cockroachdb-describe-table + source: my_cockroachdb + description: Describe the expenses table schema + + list_expenses_indexes: + type: cockroachdb-list-indexes + source: my_cockroachdb + description: List indexes on the expenses table +``` + +## Configuration Parameters + +### Required Parameters + +| Parameter | Type | Description | +|-----------|------|-------------| +| `type` | string | Must be `cockroachdb` | +| `host` | string | The hostname or IP address of the CockroachDB cluster | +| `port` | string | The port number (typically "26257") | +| `user` | string | The database user name | +| `database` | string | The database name to connect to | + +### Optional Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `password` | string | "" | The database password (can be empty for certificate-based auth) | +| `maxRetries` | integer | 5 | Maximum number of connection retry attempts | +| `retryBaseDelay` | string | "500ms" | Base delay between retry attempts (exponential backoff) | +| `queryParams` | map | {} | Additional connection parameters (e.g., SSL configuration) | + +### MCP Security Parameters + +CockroachDB integration includes security features following the [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) specification: + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `readOnlyMode` | boolean | true | Enables read-only mode by default (MCP requirement) | +| `enableWriteMode` | boolean | false | Explicitly enable write operations (INSERT/UPDATE/DELETE/CREATE/DROP) | +| `maxRowLimit` | integer | 1000 | Maximum rows returned per SELECT query (auto-adds LIMIT clause) | +| `queryTimeoutSec` | integer | 30 | Query timeout in seconds to prevent long-running queries | +| `enableTelemetry` | boolean | true | Enable structured logging of tool invocations | +| `telemetryVerbose` | boolean | false | Enable detailed JSON telemetry output | +| `clusterID` | string | "" | Optional cluster identifier for telemetry | + +### Query Parameters + +Common query parameters for CockroachDB connections: + +| Parameter | Values | Description | +|-----------|--------|-------------| +| `sslmode` | `disable`, `require`, `verify-ca`, `verify-full` | SSL/TLS mode (CockroachDB Cloud requires `require` or higher) | +| `sslrootcert` | file path | Path to root certificate for SSL verification | +| `sslcert` | file path | Path to client certificate | +| `sslkey` | file path | Path to client key | +| `application_name` | string | Application name for connection tracking | + +## Best Practices + +### Security and MCP Compliance + +**Read-Only by Default**: The integration follows MCP best practices by defaulting to read-only mode. This prevents accidental data modifications: + +```yaml +sources: + my_cockroachdb: + readOnlyMode: true # Default behavior + enableWriteMode: false # Explicit write opt-in required +``` + +To enable write operations: + +```yaml +sources: + my_cockroachdb: + readOnlyMode: false # Disable read-only protection + enableWriteMode: true # Explicitly allow writes +``` + +**Query Limits**: Automatic row limits prevent excessive data retrieval: +- SELECT queries automatically get `LIMIT 1000` appended (configurable via `maxRowLimit`) +- Queries are terminated after 30 seconds (configurable via `queryTimeoutSec`) + +**Observability**: Structured telemetry provides visibility into tool usage: +- Tool invocations are logged with status, latency, and row counts +- SQL queries are redacted to protect sensitive values +- Set `telemetryVerbose: true` for detailed JSON logs + +### Use UUID Primary Keys + +CockroachDB performs best with UUID primary keys rather than sequential integers to avoid transaction hotspots: + +```sql +CREATE TABLE expenses ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + description TEXT, + amount DECIMAL(10,2) +); +``` + +### Automatic Transaction Retry + +This source uses the official `cockroach-go/v2` library which provides automatic transaction retry for serialization conflicts. For write operations requiring explicit transaction control, tools can use the `ExecuteTxWithRetry` method. + +### Multi-Region Deployments + +CockroachDB supports multi-region deployments with automatic data distribution. Configure your cluster's regions and survival goals separately from the Toolbox configuration. The source will connect to any node in the cluster. + +### Connection Pooling + +The source maintains a connection pool to the CockroachDB cluster. The pool automatically handles: +- Load balancing across cluster nodes +- Connection retry with exponential backoff +- Health checking of connections + +## Troubleshooting + +### SSL/TLS Errors + +If you encounter "server requires encryption" errors: + +1. For CockroachDB Cloud, ensure `sslmode` is set to `require` or higher: + ```yaml + queryParams: + sslmode: require + ``` + +2. For certificate verification, download your cluster's root certificate and configure: + ```yaml + queryParams: + sslmode: verify-full + sslrootcert: /path/to/ca.crt + ``` + +### Connection Timeouts + +If experiencing connection timeouts: + +1. Check network connectivity to the CockroachDB cluster +2. Verify firewall rules allow connections on port 26257 +3. For CockroachDB Cloud, ensure IP allowlisting is configured +4. Increase `maxRetries` or `retryBaseDelay` if needed + +### Transaction Retry Errors + +CockroachDB may encounter serializable transaction conflicts. The integration automatically handles these retries using the cockroach-go library. If you see retry-related errors, check: + +1. Database load and contention +2. Query patterns that might cause conflicts +3. Consider using `SELECT FOR UPDATE` for explicit locking + +## Additional Resources + +- [CockroachDB Documentation](https://www.cockroachlabs.com/docs/) +- [CockroachDB Best Practices](https://www.cockroachlabs.com/docs/stable/performance-best-practices-overview.html) +- [Multi-Region Capabilities](https://www.cockroachlabs.com/docs/stable/multiregion-overview.html) +- [Connection Parameters](https://www.cockroachlabs.com/docs/stable/connection-parameters.html) diff --git a/docs/en/resources/tools/cockroachdb/cockroachdb-execute-sql.md b/docs/en/resources/tools/cockroachdb/cockroachdb-execute-sql.md new file mode 100644 index 0000000000..10a78926f4 --- /dev/null +++ b/docs/en/resources/tools/cockroachdb/cockroachdb-execute-sql.md @@ -0,0 +1,273 @@ +--- +title: "cockroachdb-execute-sql" +type: docs +weight: 1 +description: > + Execute ad-hoc SQL statements against a CockroachDB database. + +--- + +## About + +A `cockroachdb-execute-sql` tool executes ad-hoc SQL statements against a CockroachDB database. This tool is designed for interactive workflows where the SQL query is provided dynamically at runtime, making it ideal for developer assistance and exploratory data analysis. + +The tool takes a single `sql` parameter containing the SQL statement to execute and returns the query results. + +> **Note:** This tool is intended for developer assistant workflows with human-in-the-loop and shouldn't be used for production agents. For production use cases with predefined queries, use [cockroachdb-sql](./cockroachdb-sql.md) instead. + +## Example + +```yaml +sources: + my_cockroachdb: + type: cockroachdb + host: your-cluster.cockroachlabs.cloud + port: "26257" + user: myuser + password: mypassword + database: defaultdb + queryParams: + sslmode: require + +tools: + execute_sql: + type: cockroachdb-execute-sql + source: my_cockroachdb + description: Execute any SQL statement against the CockroachDB database +``` + +## Usage Examples + +### Simple SELECT Query + +```json +{ + "sql": "SELECT * FROM users LIMIT 10" +} +``` + +### Query with Aggregations + +```json +{ + "sql": "SELECT category, COUNT(*) as count, SUM(amount) as total FROM expenses GROUP BY category ORDER BY total DESC" +} +``` + +### Database Introspection + +```json +{ + "sql": "SHOW TABLES" +} +``` + +```json +{ + "sql": "SHOW COLUMNS FROM expenses" +} +``` + +### Multi-Region Information + +```json +{ + "sql": "SHOW REGIONS FROM DATABASE defaultdb" +} +``` + +```json +{ + "sql": "SHOW ZONE CONFIGURATIONS" +} +``` + +## CockroachDB-Specific Features + +### Check Cluster Version + +```json +{ + "sql": "SELECT version()" +} +``` + +### View Node Status + +```json +{ + "sql": "SELECT node_id, address, locality, is_live FROM crdb_internal.gossip_nodes" +} +``` + +### Check Replication Status + +```json +{ + "sql": "SELECT range_id, start_key, end_key, replicas, lease_holder FROM crdb_internal.ranges LIMIT 10" +} +``` + +### View Table Regions + +```json +{ + "sql": "SHOW REGIONS FROM TABLE expenses" +} +``` + +## Configuration + +### Required Fields + +| Field | Type | Description | +|-------|------|-------------| +| `type` | string | Must be `cockroachdb-execute-sql` | +| `source` | string | Name of the CockroachDB source to use | +| `description` | string | Human-readable description for the LLM | + +### Optional Fields + +| Field | Type | Description | +|-------|------|-------------| +| `authRequired` | array | List of authentication services required | + +## Parameters + +The tool accepts a single runtime parameter: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `sql` | string | The SQL statement to execute | + +## Best Practices + +### Use for Exploration, Not Production + +This tool is ideal for: +- Interactive database exploration +- Ad-hoc analysis and reporting +- Debugging and troubleshooting +- Schema inspection + +For production use cases, use [cockroachdb-sql](./cockroachdb-sql.md) with parameterized queries. + +### Be Cautious with Data Modification + +While this tool can execute any SQL statement, be careful with: +- `INSERT`, `UPDATE`, `DELETE` statements +- `DROP` or `ALTER` statements +- Schema changes in production + +### Use LIMIT for Large Results + +Always use `LIMIT` clauses when exploring data: + +```sql +SELECT * FROM large_table LIMIT 100 +``` + +### Leverage CockroachDB's SQL Extensions + +CockroachDB supports PostgreSQL syntax plus extensions: + +```sql +-- Show database survival goal +SHOW SURVIVAL GOAL FROM DATABASE defaultdb; + +-- View zone configurations +SHOW ZONE CONFIGURATION FOR TABLE expenses; + +-- Check table localities +SHOW CREATE TABLE expenses; +``` + +## Error Handling + +The tool will return descriptive errors for: +- **Syntax errors**: Invalid SQL syntax +- **Permission errors**: Insufficient user privileges +- **Connection errors**: Network or authentication issues +- **Runtime errors**: Constraint violations, type mismatches, etc. + +## Security Considerations + +### SQL Injection Risk + +Since this tool executes arbitrary SQL, it should only be used with: +- Trusted users in interactive sessions +- Human-in-the-loop workflows +- Development and testing environments + +Never expose this tool directly to end users without proper authorization controls. + +### Use Authentication + +Configure the `authRequired` field to restrict access: + +```yaml +tools: + execute_sql: + type: cockroachdb-execute-sql + source: my_cockroachdb + description: Execute SQL statements + authRequired: + - my-auth-service +``` + +### Read-Only Users + +For safer exploration, create read-only database users: + +```sql +CREATE USER readonly_user; +GRANT SELECT ON DATABASE defaultdb TO readonly_user; +``` + +## Common Use Cases + +### Database Administration + +```sql +-- View database size +SELECT + table_name, + pg_size_pretty(pg_total_relation_size(table_name::regclass)) AS size +FROM information_schema.tables +WHERE table_schema = 'public' +ORDER BY pg_total_relation_size(table_name::regclass) DESC; +``` + +### Performance Analysis + +```sql +-- Find slow queries +SELECT query, count, mean_latency +FROM crdb_internal.statement_statistics +WHERE mean_latency > INTERVAL '1 second' +ORDER BY mean_latency DESC +LIMIT 10; +``` + +### Data Quality Checks + +```sql +-- Find NULL values +SELECT COUNT(*) as null_count +FROM expenses +WHERE description IS NULL OR amount IS NULL; + +-- Find duplicates +SELECT user_id, email, COUNT(*) as count +FROM users +GROUP BY user_id, email +HAVING COUNT(*) > 1; +``` + +## See Also + +- [cockroachdb-sql](./cockroachdb-sql.md) - For parameterized, production-ready queries +- [cockroachdb-list-tables](./cockroachdb-list-tables.md) - List tables in the database +- [cockroachdb-list-schemas](./cockroachdb-list-schemas.md) - List database schemas +- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference +- [CockroachDB SQL Reference](https://www.cockroachlabs.com/docs/stable/sql-statements.html) - Official SQL documentation diff --git a/docs/en/resources/tools/cockroachdb/cockroachdb-list-schemas.md b/docs/en/resources/tools/cockroachdb/cockroachdb-list-schemas.md new file mode 100644 index 0000000000..8a9ee11292 --- /dev/null +++ b/docs/en/resources/tools/cockroachdb/cockroachdb-list-schemas.md @@ -0,0 +1,305 @@ +--- +title: "cockroachdb-list-schemas" +type: docs +weight: 1 +description: > + List schemas in a CockroachDB database. + +--- + +## About + +The `cockroachdb-list-schemas` tool retrieves a list of schemas (namespaces) in a CockroachDB database. Schemas are used to organize database objects such as tables, views, and functions into logical groups. + +This tool is useful for: +- Understanding database organization +- Discovering available schemas +- Multi-tenant application analysis +- Schema-level access control planning + +## Example + +```yaml +sources: + my_cockroachdb: + type: cockroachdb + host: your-cluster.cockroachlabs.cloud + port: "26257" + user: myuser + password: mypassword + database: defaultdb + queryParams: + sslmode: require + +tools: + list_schemas: + type: cockroachdb-list-schemas + source: my_cockroachdb + description: List all schemas in the database +``` + +## Configuration + +### Required Fields + +| Field | Type | Description | +|-------|------|-------------| +| `type` | string | Must be `cockroachdb-list-schemas` | +| `source` | string | Name of the CockroachDB source to use | +| `description` | string | Human-readable description for the LLM | + +### Optional Fields + +| Field | Type | Description | +|-------|------|-------------| +| `authRequired` | array | List of authentication services required | + +## Output Structure + +The tool returns a list of schemas with the following information: + +```json +[ + { + "catalog_name": "defaultdb", + "schema_name": "public", + "is_user_defined": true + }, + { + "catalog_name": "defaultdb", + "schema_name": "analytics", + "is_user_defined": true + } +] +``` + +### Fields + +| Field | Type | Description | +|-------|------|-------------| +| `catalog_name` | string | The database (catalog) name | +| `schema_name` | string | The schema name | +| `is_user_defined` | boolean | Whether this is a user-created schema (excludes system schemas) | + +## Usage Example + +```json +{} +``` + +No parameters are required. The tool automatically lists all user-defined schemas. + +## Default Schemas + +CockroachDB includes several standard schemas: + +- **`public`**: The default schema for user objects +- **`pg_catalog`**: PostgreSQL system catalog (excluded from results) +- **`information_schema`**: SQL standard metadata views (excluded from results) +- **`crdb_internal`**: CockroachDB internal metadata (excluded from results) +- **`pg_extension`**: PostgreSQL extension objects (excluded from results) + +The tool filters out system schemas and only returns user-defined schemas. + +## Schema Management in CockroachDB + +### Creating Schemas + +```sql +CREATE SCHEMA analytics; +``` + +### Using Schemas + +```sql +-- Create table in specific schema +CREATE TABLE analytics.revenue ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + amount DECIMAL(10,2), + date DATE +); + +-- Query from specific schema +SELECT * FROM analytics.revenue; +``` + +### Schema Search Path + +The search path determines which schemas are searched for unqualified object names: + +```sql +-- Show current search path +SHOW search_path; + +-- Set search path +SET search_path = analytics, public; +``` + +## Multi-Tenant Applications + +Schemas are commonly used for multi-tenant applications: + +```sql +-- Create schema per tenant +CREATE SCHEMA tenant_acme; +CREATE SCHEMA tenant_globex; + +-- Create same table structure in each schema +CREATE TABLE tenant_acme.orders (...); +CREATE TABLE tenant_globex.orders (...); +``` + +The `cockroachdb-list-schemas` tool helps discover all tenant schemas: + +```yaml +tools: + list_tenants: + type: cockroachdb-list-schemas + source: my_cockroachdb + description: | + List all tenant schemas in the database. + Each schema represents a separate tenant's data namespace. +``` + +## Best Practices + +### Use Schemas for Organization + +Group related tables into schemas: + +```sql +CREATE SCHEMA sales; +CREATE SCHEMA inventory; +CREATE SCHEMA hr; + +CREATE TABLE sales.orders (...); +CREATE TABLE inventory.products (...); +CREATE TABLE hr.employees (...); +``` + +### Schema Naming Conventions + +Use clear, descriptive schema names: +- Lowercase names +- Use underscores for multi-word names +- Avoid reserved keywords +- Use prefixes for grouped schemas (e.g., `tenant_`, `app_`) + +### Schema-Level Permissions + +Schemas enable fine-grained access control: + +```sql +-- Grant access to specific schema +GRANT USAGE ON SCHEMA analytics TO analyst_role; +GRANT SELECT ON ALL TABLES IN SCHEMA analytics TO analyst_role; + +-- Revoke access +REVOKE ALL ON SCHEMA hr FROM public; +``` + +## Integration with Other Tools + +### Combined with List Tables + +```yaml +tools: + list_schemas: + type: cockroachdb-list-schemas + source: my_cockroachdb + description: List all schemas first + + list_tables: + type: cockroachdb-list-tables + source: my_cockroachdb + description: | + List tables in the database. + Use list_schemas first to understand schema organization. +``` + +### Schema Discovery Workflow + +1. Call `cockroachdb-list-schemas` to discover schemas +2. Call `cockroachdb-list-tables` to see tables in each schema +3. Generate queries using fully qualified names: `schema.table` + +## Common Use Cases + +### Discover Database Structure + +```yaml +tools: + discover_schemas: + type: cockroachdb-list-schemas + source: my_cockroachdb + description: | + Discover how the database is organized into schemas. + Use this to understand the logical grouping of tables. +``` + +### Multi-Tenant Analysis + +```yaml +tools: + list_tenant_schemas: + type: cockroachdb-list-schemas + source: my_cockroachdb + description: | + List all tenant schemas (each tenant has their own schema). + Schema names follow the pattern: tenant_ +``` + +### Schema Migration Planning + +```yaml +tools: + audit_schemas: + type: cockroachdb-list-schemas + source: my_cockroachdb + description: | + Audit existing schemas before migration. + Identifies all schemas that need to be migrated. +``` + +## Error Handling + +The tool handles common errors: +- **Connection errors**: Returns connection failure details +- **Permission errors**: Returns error if user lacks USAGE privilege +- **Empty results**: Returns empty array if no user schemas exist + +## Permissions Required + +To list schemas, the user needs: +- `CONNECT` privilege on the database +- No specific schema privileges required for listing + +To query objects within schemas, the user needs: +- `USAGE` privilege on the schema +- Appropriate object privileges (SELECT, INSERT, etc.) + +## CockroachDB-Specific Features + +### System Schemas + +CockroachDB includes PostgreSQL-compatible system schemas plus CockroachDB-specific ones: + +- `crdb_internal.*`: CockroachDB internal metadata and statistics +- `pg_catalog.*`: PostgreSQL system catalog +- `information_schema.*`: SQL standard information schema + +These are automatically filtered from the results. + +### User-Defined Flag + +The `is_user_defined` field helps distinguish: +- `true`: User-created schemas +- `false`: System schemas (already filtered out) + +## See Also + +- [cockroachdb-sql](./cockroachdb-sql.md) - Execute parameterized queries +- [cockroachdb-execute-sql](./cockroachdb-execute-sql.md) - Execute ad-hoc SQL +- [cockroachdb-list-tables](./cockroachdb-list-tables.md) - List tables in the database +- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference +- [CockroachDB Schema Design](https://www.cockroachlabs.com/docs/stable/schema-design-overview.html) - Official documentation diff --git a/docs/en/resources/tools/cockroachdb/cockroachdb-list-tables.md b/docs/en/resources/tools/cockroachdb/cockroachdb-list-tables.md new file mode 100644 index 0000000000..339dbd2320 --- /dev/null +++ b/docs/en/resources/tools/cockroachdb/cockroachdb-list-tables.md @@ -0,0 +1,344 @@ +--- +title: "cockroachdb-list-tables" +type: docs +weight: 1 +description: > + List tables in a CockroachDB database with schema details. + +--- + +## About + +The `cockroachdb-list-tables` tool retrieves a list of tables from a CockroachDB database. It provides detailed information about table structure, including columns, constraints, indexes, and foreign key relationships. + +This tool is useful for: +- Database schema discovery +- Understanding table relationships +- Generating context for AI-powered database queries +- Documentation and analysis + +## Example + +```yaml +sources: + my_cockroachdb: + type: cockroachdb + host: your-cluster.cockroachlabs.cloud + port: "26257" + user: myuser + password: mypassword + database: defaultdb + queryParams: + sslmode: require + +tools: + list_all_tables: + type: cockroachdb-list-tables + source: my_cockroachdb + description: List all user tables in the database with their structure +``` + +## Configuration + +### Required Fields + +| Field | Type | Description | +|-------|------|-------------| +| `type` | string | Must be `cockroachdb-list-tables` | +| `source` | string | Name of the CockroachDB source to use | +| `description` | string | Human-readable description for the LLM | + +### Optional Fields + +| Field | Type | Description | +|-------|------|-------------| +| `authRequired` | array | List of authentication services required | + +## Parameters + +The tool accepts optional runtime parameters: + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `table_names` | array | all tables | List of specific table names to retrieve | +| `output_format` | string | "detailed" | Output format: "simple" or "detailed" | + +## Output Formats + +### Simple Format + +Returns basic table information: +- Table name +- Row count estimate +- Size information + +```json +{ + "table_names": ["users"], + "output_format": "simple" +} +``` + +### Detailed Format (Default) + +Returns comprehensive table information: +- Table name and schema +- All columns with types and constraints +- Primary keys +- Foreign keys and relationships +- Indexes +- Check constraints +- Table size and row counts + +```json +{ + "table_names": ["users", "orders"], + "output_format": "detailed" +} +``` + +## Usage Examples + +### List All Tables + +```json +{} +``` + +### List Specific Tables + +```json +{ + "table_names": ["users", "orders", "expenses"] +} +``` + +### Simple Output + +```json +{ + "output_format": "simple" +} +``` + +## Output Structure + +### Simple Format Output + +```json +{ + "table_name": "users", + "estimated_rows": 1000, + "size": "128 KB" +} +``` + +### Detailed Format Output + +```json +{ + "table_name": "users", + "schema": "public", + "columns": [ + { + "name": "id", + "type": "UUID", + "nullable": false, + "default": "gen_random_uuid()" + }, + { + "name": "email", + "type": "STRING", + "nullable": false, + "default": null + }, + { + "name": "created_at", + "type": "TIMESTAMP", + "nullable": false, + "default": "now()" + } + ], + "primary_key": ["id"], + "indexes": [ + { + "name": "users_pkey", + "columns": ["id"], + "unique": true, + "primary": true + }, + { + "name": "users_email_idx", + "columns": ["email"], + "unique": true, + "primary": false + } + ], + "foreign_keys": [], + "constraints": [ + { + "name": "users_email_check", + "type": "CHECK", + "definition": "email ~* '^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}$'" + } + ] +} +``` + +## CockroachDB-Specific Information + +### UUID Primary Keys + +The tool recognizes CockroachDB's recommended UUID primary key pattern: + +```sql +CREATE TABLE users ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + ... +); +``` + +### Multi-Region Tables + +For multi-region tables, the output includes locality information: + +```json +{ + "table_name": "users", + "locality": "REGIONAL BY ROW", + "regions": ["us-east-1", "us-west-2", "eu-west-1"] +} +``` + +### Interleaved Tables + +The tool shows parent-child relationships for interleaved tables (legacy feature): + +```json +{ + "table_name": "order_items", + "interleaved_in": "orders" +} +``` + +## Best Practices + +### Use for Schema Discovery + +The tool is ideal for helping AI assistants understand your database structure: + +```yaml +tools: + discover_schema: + type: cockroachdb-list-tables + source: my_cockroachdb + description: | + Use this tool first to understand the database schema before generating queries. + It shows all tables, their columns, data types, and relationships. +``` + +### Filter Large Schemas + +For databases with many tables, specify relevant tables: + +```json +{ + "table_names": ["users", "orders", "products"], + "output_format": "detailed" +} +``` + +### Use Simple Format for Overviews + +When you need just table names and sizes: + +```json +{ + "output_format": "simple" +} +``` + +## Excluded Tables + +The tool automatically excludes system tables and schemas: +- `pg_catalog.*` - PostgreSQL system catalog +- `information_schema.*` - SQL standard information schema +- `crdb_internal.*` - CockroachDB internal tables +- `pg_extension.*` - PostgreSQL extension tables + +Only user-created tables in the public schema (and other user schemas) are returned. + +## Error Handling + +The tool handles common errors: +- **Table not found**: Returns empty result for non-existent tables +- **Permission errors**: Returns error if user lacks SELECT privileges +- **Connection errors**: Returns connection failure details + +## Integration with AI Assistants + +### Prompt Example + +```yaml +tools: + list_tables: + type: cockroachdb-list-tables + source: my_cockroachdb + description: | + Lists all tables in the database with detailed schema information. + Use this tool to understand: + - What tables exist + - What columns each table has + - Data types and constraints + - Relationships between tables (foreign keys) + - Available indexes + + Always call this tool before generating SQL queries to ensure + you use correct table and column names. +``` + +## Common Use Cases + +### Generate Context for Queries + +```json +{} +``` + +This provides comprehensive schema information that helps AI assistants generate accurate SQL queries. + +### Analyze Table Structure + +```json +{ + "table_names": ["users"], + "output_format": "detailed" +} +``` + +Perfect for understanding a specific table's structure, constraints, and relationships. + +### Quick Schema Overview + +```json +{ + "output_format": "simple" +} +``` + +Gets a quick list of tables with basic statistics. + +## Performance Considerations + +- **Simple format** is faster for large databases +- **Detailed format** queries system tables extensively +- Specifying `table_names` reduces query time +- Results are fetched in a single query for efficiency + +## See Also + +- [cockroachdb-sql](./cockroachdb-sql.md) - Execute parameterized queries +- [cockroachdb-execute-sql](./cockroachdb-execute-sql.md) - Execute ad-hoc SQL +- [cockroachdb-list-schemas](./cockroachdb-list-schemas.md) - List database schemas +- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference +- [CockroachDB Schema Design](https://www.cockroachlabs.com/docs/stable/schema-design-overview.html) - Best practices diff --git a/docs/en/resources/tools/cockroachdb/cockroachdb-sql.md b/docs/en/resources/tools/cockroachdb/cockroachdb-sql.md new file mode 100644 index 0000000000..aa31edcd52 --- /dev/null +++ b/docs/en/resources/tools/cockroachdb/cockroachdb-sql.md @@ -0,0 +1,291 @@ +--- +title: "cockroachdb-sql" +type: docs +weight: 1 +description: > + Execute parameterized SQL queries in CockroachDB. + +--- + +## About + +The `cockroachdb-sql` tool allows you to execute parameterized SQL queries against a CockroachDB database. This tool supports prepared statements with parameter binding, template parameters for dynamic query construction, and automatic transaction retry for resilience against serialization conflicts. + +## Example + +```yaml +sources: + my_cockroachdb: + type: cockroachdb + host: your-cluster.cockroachlabs.cloud + port: "26257" + user: myuser + password: mypassword + database: defaultdb + queryParams: + sslmode: require + +tools: + get_user_orders: + type: cockroachdb-sql + source: my_cockroachdb + description: Get all orders for a specific user + statement: | + SELECT o.id, o.order_date, o.total_amount, o.status + FROM orders o + WHERE o.user_id = $1 + ORDER BY o.order_date DESC + parameters: + - name: user_id + type: string + description: The UUID of the user +``` + +## Configuration + +### Required Fields + +| Field | Type | Description | +|-------|------|-------------| +| `type` | string | Must be `cockroachdb-sql` | +| `source` | string | Name of the CockroachDB source to use | +| `description` | string | Human-readable description of what the tool does | +| `statement` | string | The SQL query to execute | + +### Optional Fields + +| Field | Type | Description | +|-------|------|-------------| +| `parameters` | array | List of parameter definitions for the query | +| `templateParameters` | array | List of template parameters for dynamic query construction | +| `authRequired` | array | List of authentication services required | + +## Parameters + +Parameters allow you to safely pass values into your SQL queries using prepared statements. CockroachDB uses PostgreSQL-style parameter placeholders: `$1`, `$2`, etc. + +### Parameter Types + +- `string`: Text values +- `number`: Numeric values (integers or decimals) +- `boolean`: True/false values +- `array`: Array of values + +### Example with Multiple Parameters + +```yaml +tools: + filter_expenses: + type: cockroachdb-sql + source: my_cockroachdb + description: Filter expenses by category and date range + statement: | + SELECT id, description, amount, category, expense_date + FROM expenses + WHERE user_id = $1 + AND category = $2 + AND expense_date >= $3 + AND expense_date <= $4 + ORDER BY expense_date DESC + parameters: + - name: user_id + type: string + description: The user's UUID + - name: category + type: string + description: Expense category (e.g., "Food", "Transport") + - name: start_date + type: string + description: Start date in YYYY-MM-DD format + - name: end_date + type: string + description: End date in YYYY-MM-DD format +``` + +## Template Parameters + +Template parameters enable dynamic query construction by replacing placeholders in the SQL statement before parameter binding. This is useful for dynamic table names, column names, or query structure. + +### Example with Template Parameters + +```yaml +tools: + get_column_data: + type: cockroachdb-sql + source: my_cockroachdb + description: Get data from a specific column + statement: | + SELECT {{column_name}} + FROM {{table_name}} + WHERE user_id = $1 + LIMIT 100 + templateParameters: + - name: table_name + type: string + description: The table to query + - name: column_name + type: string + description: The column to retrieve + parameters: + - name: user_id + type: string + description: The user's UUID +``` + +## Best Practices + +### Use UUID Primary Keys + +CockroachDB performs best with UUID primary keys to avoid transaction hotspots: + +```sql +CREATE TABLE orders ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL, + order_date TIMESTAMP DEFAULT now(), + total_amount DECIMAL(10,2) +); +``` + +### Use Indexes for Performance + +Create indexes on frequently queried columns: + +```sql +CREATE INDEX idx_orders_user_id ON orders(user_id); +CREATE INDEX idx_orders_date ON orders(order_date DESC); +``` + +### Use JOINs Efficiently + +CockroachDB supports standard SQL JOINs. Keep joins efficient by: +- Adding appropriate indexes +- Using UUIDs for foreign keys +- Limiting result sets with WHERE clauses + +```yaml +tools: + get_user_with_orders: + type: cockroachdb-sql + source: my_cockroachdb + description: Get user details with their recent orders + statement: | + SELECT u.name, u.email, o.id as order_id, o.order_date, o.total_amount + FROM users u + LEFT JOIN orders o ON u.id = o.user_id + WHERE u.id = $1 + ORDER BY o.order_date DESC + LIMIT 10 + parameters: + - name: user_id + type: string + description: The user's UUID +``` + +### Handle NULL Values + +Use COALESCE or NULL checks when dealing with nullable columns: + +```sql +SELECT id, description, COALESCE(notes, 'No notes') as notes +FROM expenses +WHERE user_id = $1 +``` + +## Error Handling + +The tool automatically handles: +- **Connection errors**: Retried with exponential backoff +- **Serialization conflicts**: Automatically retried using cockroach-go library +- **Invalid parameters**: Returns descriptive error messages +- **SQL syntax errors**: Returns database error details + +## Advanced Usage + +### Aggregations + +```yaml +tools: + expense_summary: + type: cockroachdb-sql + source: my_cockroachdb + description: Get expense summary by category for a user + statement: | + SELECT + category, + COUNT(*) as count, + SUM(amount) as total_amount, + AVG(amount) as avg_amount + FROM expenses + WHERE user_id = $1 + AND expense_date >= $2 + GROUP BY category + ORDER BY total_amount DESC + parameters: + - name: user_id + type: string + description: The user's UUID + - name: start_date + type: string + description: Start date in YYYY-MM-DD format +``` + +### Window Functions + +```yaml +tools: + running_total: + type: cockroachdb-sql + source: my_cockroachdb + description: Get running total of expenses + statement: | + SELECT + expense_date, + amount, + SUM(amount) OVER (ORDER BY expense_date) as running_total + FROM expenses + WHERE user_id = $1 + ORDER BY expense_date + parameters: + - name: user_id + type: string + description: The user's UUID +``` + +### Common Table Expressions (CTEs) + +```yaml +tools: + top_spenders: + type: cockroachdb-sql + source: my_cockroachdb + description: Find top spending users + statement: | + WITH user_totals AS ( + SELECT + user_id, + SUM(amount) as total_spent + FROM expenses + WHERE expense_date >= $1 + GROUP BY user_id + ) + SELECT + u.name, + u.email, + ut.total_spent + FROM user_totals ut + JOIN users u ON ut.user_id = u.id + ORDER BY ut.total_spent DESC + LIMIT 10 + parameters: + - name: start_date + type: string + description: Start date in YYYY-MM-DD format +``` + +## See Also + +- [cockroachdb-execute-sql](./cockroachdb-execute-sql.md) - For ad-hoc SQL execution +- [cockroachdb-list-tables](./cockroachdb-list-tables.md) - List tables in the database +- [cockroachdb-list-schemas](./cockroachdb-list-schemas.md) - List database schemas +- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference diff --git a/go.mod b/go.mod index f74a7dc2c6..9fd4617e1b 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.30.0 github.com/apache/cassandra-gocql-driver/v2 v2.0.0 github.com/cenkalti/backoff/v5 v5.0.3 + github.com/cockroachdb/cockroach-go/v2 v2.4.2 github.com/couchbase/gocb/v2 v2.11.1 github.com/couchbase/tools-common/http v1.0.9 github.com/elastic/elastic-transport-go/v8 v8.8.0 diff --git a/go.sum b/go.sum index c3f59c72d4..8ebebe11d5 100644 --- a/go.sum +++ b/go.sum @@ -800,6 +800,8 @@ github.com/cncf/xds/go v0.0.0-20230105202645-06c439db220b/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 h1:aQ3y1lwWyqYPiWZThqv1aFbZMiM9vblcSArJRf2Irls= github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= +github.com/cockroachdb/cockroach-go/v2 v2.4.2 h1:QB0ozDWQUUJ0GP8Zw63X/qHefPTCpLvtfCs6TLrPgyE= +github.com/cockroachdb/cockroach-go/v2 v2.4.2/go.mod h1:9U179XbCx4qFWtNhc7BiWLPfuyMVQ7qdAhfrwLz1vH0= github.com/containerd/continuity v0.4.5 h1:ZRoN1sXq9u7V6QoHMcVWGhOwDFqZ4B9i5H6un1Wh0x4= github.com/containerd/continuity v0.4.5/go.mod h1:/lNJvtJKUQStBzpVQ1+rasXO1LAWtUQssk28EZvJ3nE= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= @@ -960,6 +962,8 @@ github.com/godror/godror v0.49.6 h1:ts4ZGw8uLJ42e1D7aXmVuSrld0/lzUzmIUjuUuQOgGM= github.com/godror/godror v0.49.6/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8= github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw= github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU= +github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= +github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= diff --git a/internal/sources/cockroachdb/cockroachdb.go b/internal/sources/cockroachdb/cockroachdb.go new file mode 100644 index 0000000000..5a90fcf53e --- /dev/null +++ b/internal/sources/cockroachdb/cockroachdb.go @@ -0,0 +1,430 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdb + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "math" + "net/url" + "regexp" + "strings" + "time" + + crdbpgx "github.com/cockroachdb/cockroach-go/v2/crdb/crdbpgxv5" + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "go.opentelemetry.io/otel/trace" +) + +const SourceKind string = "cockroachdb" +const SourceType string = "cockroachdb" + +var _ sources.SourceConfig = Config{} + +func init() { + if !sources.Register(SourceKind, newConfig) { + panic(fmt.Sprintf("source kind %q already registered", SourceKind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) { + // MCP compliance: Read-only by default, require explicit opt-in for writes + actual := Config{ + Name: name, + MaxRetries: 5, + RetryBaseDelay: "500ms", + ReadOnlyMode: true, // MCP requirement: read-only by default + EnableWriteMode: false, // Must be explicitly enabled + MaxRowLimit: 1000, // MCP requirement: limit query results + QueryTimeoutSec: 30, // MCP requirement: prevent long-running queries + EnableTelemetry: true, // MCP requirement: observability + TelemetryVerbose: false, + } + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + + // Security validation: If EnableWriteMode is true, ReadOnlyMode should be false + if actual.EnableWriteMode { + actual.ReadOnlyMode = false + } + + return actual, nil +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Host string `yaml:"host" validate:"required"` + Port string `yaml:"port" validate:"required"` + User string `yaml:"user" validate:"required"` + Password string `yaml:"password"` + Database string `yaml:"database" validate:"required"` + QueryParams map[string]string `yaml:"queryParams"` + MaxRetries int `yaml:"maxRetries"` + RetryBaseDelay string `yaml:"retryBaseDelay"` + + // MCP Security Features + ReadOnlyMode bool `yaml:"readOnlyMode"` // Default: true (enforced in Initialize) + EnableWriteMode bool `yaml:"enableWriteMode"` // Explicit opt-in for write operations + MaxRowLimit int `yaml:"maxRowLimit"` // Default: 1000 + QueryTimeoutSec int `yaml:"queryTimeoutSec"` // Default: 30 + + // Observability + EnableTelemetry bool `yaml:"enableTelemetry"` // Default: true + TelemetryVerbose bool `yaml:"telemetryVerbose"` // Default: false + ClusterID string `yaml:"clusterID"` // Optional cluster identifier for telemetry +} + +func (r Config) SourceConfigKind() string { + return SourceKind +} + +func (r Config) SourceConfigType() string { + return SourceType +} + +func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) { + retryBaseDelay, err := time.ParseDuration(r.RetryBaseDelay) + if err != nil { + return nil, fmt.Errorf("invalid retryBaseDelay: %w", err) + } + + pool, err := initCockroachDBConnectionPoolWithRetry(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryParams, r.MaxRetries, retryBaseDelay) + if err != nil { + return nil, fmt.Errorf("unable to create pool: %w", err) + } + + s := &Source{ + Config: r, + Pool: pool, + } + return s, nil +} + +var _ sources.Source = &Source{} + +type Source struct { + Config + Pool *pgxpool.Pool +} + +func (s *Source) SourceKind() string { + return SourceKind +} + +func (s *Source) SourceType() string { + return SourceType +} + +func (s *Source) ToConfig() sources.SourceConfig { + return s.Config +} + +func (s *Source) CockroachDBPool() *pgxpool.Pool { + return s.Pool +} + +func (s *Source) PostgresPool() *pgxpool.Pool { + return s.Pool +} + +// ExecuteTxWithRetry executes a function within a transaction with automatic retry logic +// using the official CockroachDB retry mechanism from cockroach-go/v2 +func (s *Source) ExecuteTxWithRetry(ctx context.Context, fn func(pgx.Tx) error) error { + return crdbpgx.ExecuteTx(ctx, s.Pool, pgx.TxOptions{}, fn) +} + +// Query executes a query using the connection pool with MCP security enforcement. +// For read-only queries, connection-level retry is sufficient. +// For write operations requiring transaction retry, use ExecuteTxWithRetry directly. +// Note: Callers should manage context timeouts as needed. +func (s *Source) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { + // MCP Security Check 1: Enforce write operation restrictions + if err := s.CanExecuteWrite(sql); err != nil { + return nil, err + } + + // MCP Security Check 2: Apply query limits (row limit) + modifiedSQL, err := s.ApplyQueryLimits(sql) + if err != nil { + return nil, err + } + + return s.Pool.Query(ctx, modifiedSQL, args...) +} + +// ============================================================================ +// MCP Security & Observability Features +// ============================================================================ + +// TelemetryEvent represents a structured telemetry event for MCP tool calls +type TelemetryEvent struct { + Timestamp time.Time `json:"timestamp"` + ToolName string `json:"tool_name"` + ClusterID string `json:"cluster_id"` + Database string `json:"database"` + User string `json:"user"` + SQLRedacted string `json:"sql_redacted"` // Query with values redacted + Status string `json:"status"` // "success" | "failure" + ErrorCode string `json:"error_code,omitempty"` + ErrorMsg string `json:"error_msg,omitempty"` + LatencyMs int64 `json:"latency_ms"` + RowsAffected int64 `json:"rows_affected,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// StructuredError represents an MCP-compliant error with error codes +type StructuredError struct { + Code string `json:"error_code"` + Message string `json:"message"` + Details map[string]any `json:"details,omitempty"` +} + +func (e *StructuredError) Error() string { + return fmt.Sprintf("[%s] %s", e.Code, e.Message) +} + +// MCP Error Codes +const ( + ErrCodeUnauthorized = "CRDB_UNAUTHORIZED" + ErrCodeReadOnlyViolation = "CRDB_READONLY_VIOLATION" + ErrCodeQueryTimeout = "CRDB_QUERY_TIMEOUT" + ErrCodeRowLimitExceeded = "CRDB_ROW_LIMIT_EXCEEDED" + ErrCodeInvalidSQL = "CRDB_INVALID_SQL" + ErrCodeConnectionFailed = "CRDB_CONNECTION_FAILED" + ErrCodeWriteModeRequired = "CRDB_WRITE_MODE_REQUIRED" + ErrCodeQueryExecutionFailed = "CRDB_QUERY_EXECUTION_FAILED" +) + +// SQLStatementType represents the type of SQL statement +type SQLStatementType int + +const ( + SQLTypeUnknown SQLStatementType = iota + SQLTypeSelect + SQLTypeInsert + SQLTypeUpdate + SQLTypeDelete + SQLTypeDDL // CREATE, ALTER, DROP + SQLTypeTruncate + SQLTypeExplain + SQLTypeShow + SQLTypeSet +) + +// ClassifySQL analyzes a SQL statement and returns its type +func ClassifySQL(sql string) SQLStatementType { + // Normalize: trim and convert to uppercase for analysis + normalized := strings.TrimSpace(strings.ToUpper(sql)) + + if normalized == "" { + return SQLTypeUnknown + } + + // Remove comments + normalized = regexp.MustCompile(`--.*`).ReplaceAllString(normalized, "") + normalized = regexp.MustCompile(`/\*.*?\*/`).ReplaceAllString(normalized, "") + normalized = strings.TrimSpace(normalized) + + // Check statement type + switch { + case strings.HasPrefix(normalized, "SELECT"): + return SQLTypeSelect + case strings.HasPrefix(normalized, "INSERT"): + return SQLTypeInsert + case strings.HasPrefix(normalized, "UPDATE"): + return SQLTypeUpdate + case strings.HasPrefix(normalized, "DELETE"): + return SQLTypeDelete + case strings.HasPrefix(normalized, "TRUNCATE"): + return SQLTypeTruncate + case strings.HasPrefix(normalized, "CREATE"): + return SQLTypeDDL + case strings.HasPrefix(normalized, "ALTER"): + return SQLTypeDDL + case strings.HasPrefix(normalized, "DROP"): + return SQLTypeDDL + case strings.HasPrefix(normalized, "EXPLAIN"): + return SQLTypeExplain + case strings.HasPrefix(normalized, "SHOW"): + return SQLTypeShow + case strings.HasPrefix(normalized, "SET"): + return SQLTypeSet + default: + return SQLTypeUnknown + } +} + +// IsWriteOperation returns true if the SQL statement modifies data +func IsWriteOperation(sqlType SQLStatementType) bool { + switch sqlType { + case SQLTypeInsert, SQLTypeUpdate, SQLTypeDelete, SQLTypeTruncate, SQLTypeDDL: + return true + default: + return false + } +} + +// IsReadOnlyMode returns whether the source is in read-only mode +func (s *Source) IsReadOnlyMode() bool { + return s.ReadOnlyMode && !s.EnableWriteMode +} + +// CanExecuteWrite checks if a write operation is allowed +func (s *Source) CanExecuteWrite(sql string) error { + sqlType := ClassifySQL(sql) + + if IsWriteOperation(sqlType) && s.IsReadOnlyMode() { + return &StructuredError{ + Code: ErrCodeReadOnlyViolation, + Message: "Write operations are not allowed in read-only mode. Set enableWriteMode: true to allow writes.", + Details: map[string]any{ + "sql_type": sqlType, + "read_only_mode": s.ReadOnlyMode, + "enable_write_mode": s.EnableWriteMode, + }, + } + } + + return nil +} + +// ApplyQueryLimits applies row limits to a SQL query for MCP security compliance. +// Context timeout management is the responsibility of the caller (following Go best practices). +// Returns potentially modified SQL with LIMIT clause for SELECT queries. +func (s *Source) ApplyQueryLimits(sql string) (string, error) { + sqlType := ClassifySQL(sql) + + // Apply row limit only to SELECT queries + if sqlType == SQLTypeSelect && s.MaxRowLimit > 0 { + // Check if query already has LIMIT clause + normalized := strings.ToUpper(sql) + if !strings.Contains(normalized, " LIMIT ") { + // Add LIMIT clause - trim trailing whitespace and semicolon + sql = strings.TrimSpace(sql) + sql = strings.TrimSuffix(sql, ";") + sql = fmt.Sprintf("%s LIMIT %d", sql, s.MaxRowLimit) + } + } + + return sql, nil +} + +// RedactSQL redacts sensitive values from SQL for telemetry +func RedactSQL(sql string) string { + // Redact string literals + sql = regexp.MustCompile(`'[^']*'`).ReplaceAllString(sql, "'***'") + + // Redact numbers that might be sensitive + sql = regexp.MustCompile(`\b\d{10,}\b`).ReplaceAllString(sql, "***") + + return sql +} + +// EmitTelemetry logs a telemetry event in structured JSON format +func (s *Source) EmitTelemetry(ctx context.Context, event TelemetryEvent) { + if !s.EnableTelemetry { + return + } + + // Set cluster ID if not already set + if event.ClusterID == "" { + event.ClusterID = s.ClusterID + if event.ClusterID == "" { + event.ClusterID = s.Database // Fallback to database name + } + } + + // Set database and user + if event.Database == "" { + event.Database = s.Database + } + if event.User == "" { + event.User = s.User + } + + // Log as structured JSON + if s.TelemetryVerbose { + jsonBytes, _ := json.Marshal(event) + slog.Info("CockroachDB MCP Telemetry", "event", string(jsonBytes)) + } else { + // Minimal logging + slog.Info("CockroachDB MCP", + "tool", event.ToolName, + "status", event.Status, + "latency_ms", event.LatencyMs, + "error_code", event.ErrorCode, + ) + } +} + +func initCockroachDBConnectionPoolWithRetry(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string, maxRetries int, baseDelay time.Duration) (*pgxpool.Pool, error) { + //nolint:all + ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) + defer span.End() + + userAgent, err := util.UserAgentFromContext(ctx) + if err != nil { + userAgent = "genai-toolbox" + } + if queryParams == nil { + queryParams = make(map[string]string) + } + if _, ok := queryParams["application_name"]; !ok { + queryParams["application_name"] = userAgent + } + + connURL := &url.URL{ + Scheme: "postgres", + User: url.UserPassword(user, pass), + Host: fmt.Sprintf("%s:%s", host, port), + Path: dbname, + RawQuery: ConvertParamMapToRawQuery(queryParams), + } + + var pool *pgxpool.Pool + for attempt := 0; attempt <= maxRetries; attempt++ { + pool, err = pgxpool.New(ctx, connURL.String()) + if err == nil { + err = pool.Ping(ctx) + } + + if err == nil { + return pool, nil + } + + if attempt < maxRetries { + backoff := baseDelay * time.Duration(math.Pow(2, float64(attempt))) + time.Sleep(backoff) + } + } + + return nil, fmt.Errorf("failed to connect to CockroachDB after %d retries: %w", maxRetries, err) +} + +func ConvertParamMapToRawQuery(queryParams map[string]string) string { + values := url.Values{} + for k, v := range queryParams { + values.Add(k, v) + } + return values.Encode() +} diff --git a/internal/sources/cockroachdb/cockroachdb_test.go b/internal/sources/cockroachdb/cockroachdb_test.go new file mode 100644 index 0000000000..db14542c13 --- /dev/null +++ b/internal/sources/cockroachdb/cockroachdb_test.go @@ -0,0 +1,224 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdb + +import ( + "context" + "strings" + "testing" + + "github.com/goccy/go-yaml" +) + +func TestCockroachDBSourceConfig(t *testing.T) { + tests := []struct { + name string + yaml string + }{ + { + name: "valid config", + yaml: ` +name: test-cockroachdb +type: cockroachdb +host: localhost +port: "26257" +user: root +password: "" +database: defaultdb +maxRetries: 5 +retryBaseDelay: 500ms +queryParams: + sslmode: disable +`, + }, + { + name: "with optional queryParams", + yaml: ` +name: test-cockroachdb +type: cockroachdb +host: localhost +port: "26257" +user: root +password: testpass +database: testdb +queryParams: + sslmode: require + sslcert: /path/to/cert +`, + }, + { + name: "with custom retry settings", + yaml: ` +name: test-cockroachdb +type: cockroachdb +host: localhost +port: "26257" +user: root +password: "" +database: defaultdb +maxRetries: 10 +retryBaseDelay: 1s +`, + }, + { + name: "without password (insecure mode)", + yaml: ` +name: test-cockroachdb +type: cockroachdb +host: localhost +port: "26257" +user: root +database: defaultdb +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decoder := yaml.NewDecoder(strings.NewReader(tt.yaml)) + cfg, err := newConfig(context.Background(), "test", decoder) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg == nil { + t.Fatal("expected config but got nil") + } + + // Verify it's the right type + cockroachCfg, ok := cfg.(Config) + if !ok { + t.Fatalf("expected Config type, got %T", cfg) + } + + // Verify SourceConfigType + if cockroachCfg.SourceConfigType() != SourceType { + t.Errorf("expected SourceConfigType %q, got %q", SourceType, cockroachCfg.SourceConfigType()) + } + + t.Logf("✅ Config parsed successfully: %+v", cockroachCfg) + }) + } +} + +func TestCockroachDBSourceType(t *testing.T) { + yamlContent := ` +name: test-cockroachdb +type: cockroachdb +host: localhost +port: "26257" +user: root +password: "" +database: defaultdb +` + decoder := yaml.NewDecoder(strings.NewReader(yamlContent)) + cfg, err := newConfig(context.Background(), "test", decoder) + if err != nil { + t.Fatalf("failed to create config: %v", err) + } + + if cfg.SourceConfigType() != "cockroachdb" { + t.Errorf("expected SourceConfigType 'cockroachdb', got %q", cfg.SourceConfigType()) + } +} + +func TestCockroachDBDefaultValues(t *testing.T) { + yamlContent := ` +name: test-cockroachdb +type: cockroachdb +host: localhost +port: "26257" +user: root +password: "" +database: defaultdb +` + decoder := yaml.NewDecoder(strings.NewReader(yamlContent)) + cfg, err := newConfig(context.Background(), "test", decoder) + if err != nil { + t.Fatalf("failed to create config: %v", err) + } + + cockroachCfg, ok := cfg.(Config) + if !ok { + t.Fatalf("expected Config type") + } + + // Check default values + if cockroachCfg.MaxRetries != 5 { + t.Errorf("expected default MaxRetries 5, got %d", cockroachCfg.MaxRetries) + } + + if cockroachCfg.RetryBaseDelay != "500ms" { + t.Errorf("expected default RetryBaseDelay '500ms', got %q", cockroachCfg.RetryBaseDelay) + } + + t.Logf("✅ Default values set correctly") +} + +func TestConvertParamMapToRawQuery(t *testing.T) { + tests := []struct { + name string + params map[string]string + want []string // Expected substrings in any order + }{ + { + name: "empty params", + params: map[string]string{}, + want: []string{}, + }, + { + name: "single param", + params: map[string]string{ + "sslmode": "disable", + }, + want: []string{"sslmode=disable"}, + }, + { + name: "multiple params", + params: map[string]string{ + "sslmode": "require", + "application_name": "test-app", + }, + want: []string{"sslmode=require", "application_name=test-app"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertParamMapToRawQuery(tt.params) + + if len(tt.want) == 0 { + if result != "" { + t.Errorf("expected empty string, got %q", result) + } + return + } + + // Check that all expected substrings are in the result + for _, want := range tt.want { + if !contains(result, want) { + t.Errorf("expected result to contain %q, got %q", want, result) + } + } + + t.Logf("✅ Query string: %s", result) + }) + } +} + +func contains(s, substr string) bool { + return strings.Contains(s, substr) +} diff --git a/internal/sources/cockroachdb/security_test.go b/internal/sources/cockroachdb/security_test.go new file mode 100644 index 0000000000..c80ad5b36e --- /dev/null +++ b/internal/sources/cockroachdb/security_test.go @@ -0,0 +1,455 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdb + +import ( + "context" + "strings" + "testing" + "time" + + yaml "github.com/goccy/go-yaml" +) + +// TestClassifySQL tests SQL statement classification +func TestClassifySQL(t *testing.T) { + tests := []struct { + name string + sql string + expected SQLStatementType + }{ + {"SELECT", "SELECT * FROM users", SQLTypeSelect}, + {"SELECT with spaces", " SELECT * FROM users ", SQLTypeSelect}, + {"SELECT with comment", "-- comment\nSELECT * FROM users", SQLTypeSelect}, + {"INSERT", "INSERT INTO users (name) VALUES ('alice')", SQLTypeInsert}, + {"UPDATE", "UPDATE users SET name='bob' WHERE id=1", SQLTypeUpdate}, + {"DELETE", "DELETE FROM users WHERE id=1", SQLTypeDelete}, + {"CREATE TABLE", "CREATE TABLE users (id UUID PRIMARY KEY)", SQLTypeDDL}, + {"ALTER TABLE", "ALTER TABLE users ADD COLUMN email STRING", SQLTypeDDL}, + {"DROP TABLE", "DROP TABLE users", SQLTypeDDL}, + {"TRUNCATE", "TRUNCATE TABLE users", SQLTypeTruncate}, + {"EXPLAIN", "EXPLAIN SELECT * FROM users", SQLTypeExplain}, + {"SHOW", "SHOW TABLES", SQLTypeShow}, + {"SET", "SET application_name = 'myapp'", SQLTypeSet}, + {"Empty", "", SQLTypeUnknown}, + {"Lowercase select", "select * from users", SQLTypeSelect}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ClassifySQL(tt.sql) + if result != tt.expected { + t.Errorf("ClassifySQL(%q) = %v, want %v", tt.sql, result, tt.expected) + } + }) + } +} + +// TestIsWriteOperation tests write operation detection +func TestIsWriteOperation(t *testing.T) { + tests := []struct { + sqlType SQLStatementType + expected bool + }{ + {SQLTypeSelect, false}, + {SQLTypeInsert, true}, + {SQLTypeUpdate, true}, + {SQLTypeDelete, true}, + {SQLTypeTruncate, true}, + {SQLTypeDDL, true}, + {SQLTypeExplain, false}, + {SQLTypeShow, false}, + {SQLTypeSet, false}, + {SQLTypeUnknown, false}, + } + + for _, tt := range tests { + t.Run(tt.sqlType.String(), func(t *testing.T) { + result := IsWriteOperation(tt.sqlType) + if result != tt.expected { + t.Errorf("IsWriteOperation(%v) = %v, want %v", tt.sqlType, result, tt.expected) + } + }) + } +} + +// Helper for SQLStatementType to string +func (s SQLStatementType) String() string { + switch s { + case SQLTypeSelect: + return "SELECT" + case SQLTypeInsert: + return "INSERT" + case SQLTypeUpdate: + return "UPDATE" + case SQLTypeDelete: + return "DELETE" + case SQLTypeDDL: + return "DDL" + case SQLTypeTruncate: + return "TRUNCATE" + case SQLTypeExplain: + return "EXPLAIN" + case SQLTypeShow: + return "SHOW" + case SQLTypeSet: + return "SET" + default: + return "UNKNOWN" + } +} + +// TestCanExecuteWrite tests write operation enforcement +func TestCanExecuteWrite(t *testing.T) { + tests := []struct { + name string + readOnlyMode bool + enableWriteMode bool + sql string + expectError bool + errorCode string + }{ + { + name: "SELECT in read-only mode", + readOnlyMode: true, + enableWriteMode: false, + sql: "SELECT * FROM users", + expectError: false, + }, + { + name: "INSERT in read-only mode", + readOnlyMode: true, + enableWriteMode: false, + sql: "INSERT INTO users (name) VALUES ('alice')", + expectError: true, + errorCode: ErrCodeReadOnlyViolation, + }, + { + name: "INSERT with write mode enabled", + readOnlyMode: false, + enableWriteMode: true, + sql: "INSERT INTO users (name) VALUES ('alice')", + expectError: false, + }, + { + name: "CREATE TABLE in read-only mode", + readOnlyMode: true, + enableWriteMode: false, + sql: "CREATE TABLE test (id UUID PRIMARY KEY)", + expectError: true, + errorCode: ErrCodeReadOnlyViolation, + }, + { + name: "CREATE TABLE with write mode enabled", + readOnlyMode: false, + enableWriteMode: true, + sql: "CREATE TABLE test (id UUID PRIMARY KEY)", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + source := &Source{ + Config: Config{ + ReadOnlyMode: tt.readOnlyMode, + EnableWriteMode: tt.enableWriteMode, + }, + } + + err := source.CanExecuteWrite(tt.sql) + + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got nil") + return + } + + structErr, ok := err.(*StructuredError) + if !ok { + t.Errorf("Expected StructuredError but got %T", err) + return + } + + if structErr.Code != tt.errorCode { + t.Errorf("Expected error code %s but got %s", tt.errorCode, structErr.Code) + } + } else { + if err != nil { + t.Errorf("Expected no error but got: %v", err) + } + } + }) + } +} + +// TestApplyQueryLimits tests query limit application +func TestApplyQueryLimits(t *testing.T) { + tests := []struct { + name string + sql string + maxRowLimit int + expectedSQL string + shouldAddLimit bool + }{ + { + name: "SELECT without LIMIT", + sql: "SELECT * FROM users", + maxRowLimit: 100, + expectedSQL: "SELECT * FROM users LIMIT 100", + shouldAddLimit: true, + }, + { + name: "SELECT with existing LIMIT", + sql: "SELECT * FROM users LIMIT 50", + maxRowLimit: 100, + expectedSQL: "SELECT * FROM users LIMIT 50", + shouldAddLimit: false, + }, + { + name: "SELECT without LIMIT and semicolon", + sql: "SELECT * FROM users;", + maxRowLimit: 100, + expectedSQL: "SELECT * FROM users LIMIT 100", + shouldAddLimit: true, + }, + { + name: "SELECT with trailing newline and semicolon", + sql: "SELECT * FROM users;\n", + maxRowLimit: 100, + expectedSQL: "SELECT * FROM users LIMIT 100", + shouldAddLimit: true, + }, + { + name: "SELECT with multiline and semicolon", + sql: "\n\tSELECT *\n\tFROM users\n\tORDER BY id;\n", + maxRowLimit: 100, + expectedSQL: "SELECT *\n\tFROM users\n\tORDER BY id LIMIT 100", + shouldAddLimit: true, + }, + { + name: "INSERT should not have LIMIT added", + sql: "INSERT INTO users (name) VALUES ('alice')", + maxRowLimit: 100, + expectedSQL: "INSERT INTO users (name) VALUES ('alice')", + shouldAddLimit: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + source := &Source{ + Config: Config{ + MaxRowLimit: tt.maxRowLimit, + QueryTimeoutSec: 0, // Timeout now managed by caller + }, + } + + modifiedSQL, err := source.ApplyQueryLimits(tt.sql) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if modifiedSQL != tt.expectedSQL { + t.Errorf("Expected SQL:\n%s\nGot:\n%s", tt.expectedSQL, modifiedSQL) + } + }) + } +} + +// TestApplyQueryTimeout tests that timeout is managed by caller (not source) +func TestApplyQueryTimeout(t *testing.T) { + source := &Source{ + Config: Config{ + QueryTimeoutSec: 5, // Documented recommended timeout + MaxRowLimit: 0, // Don't add LIMIT + }, + } + + // Caller creates timeout context (following Go best practices) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Duration(source.QueryTimeoutSec)*time.Second) + defer cancel() + + // Apply query limits (doesn't modify context anymore) + modifiedSQL, err := source.ApplyQueryLimits("SELECT * FROM users") + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + // Verify context has deadline (managed by caller) + deadline, ok := ctx.Deadline() + if !ok { + t.Error("Expected deadline to be set but it wasn't") + return + } + + // Verify deadline is approximately 5 seconds from now + expectedDeadline := time.Now().Add(5 * time.Second) + diff := deadline.Sub(expectedDeadline) + if diff < 0 { + diff = -diff + } + + // Allow 1 second tolerance + if diff > time.Second { + t.Errorf("Deadline diff too large: %v", diff) + } + + // Verify SQL is unchanged (LIMIT not added since MaxRowLimit=0) + if modifiedSQL != "SELECT * FROM users" { + t.Errorf("Expected SQL unchanged, got: %s", modifiedSQL) + } +} + +// TestRedactSQL tests SQL redaction for telemetry +func TestRedactSQL(t *testing.T) { + tests := []struct { + name string + sql string + expected string + }{ + { + name: "String literal redaction", + sql: "SELECT * FROM users WHERE name='alice' AND email='alice@example.com'", + expected: "SELECT * FROM users WHERE name='***' AND email='***'", + }, + { + name: "Long number redaction", + sql: "SELECT * FROM users WHERE ssn=1234567890123", + expected: "SELECT * FROM users WHERE ssn=***", + }, + { + name: "Short numbers not redacted", + sql: "SELECT * FROM users WHERE age=25", + expected: "SELECT * FROM users WHERE age=25", + }, + { + name: "Multiple sensitive values", + sql: "INSERT INTO users (name, email, phone) VALUES ('bob', 'bob@example.com', '5551234567')", + expected: "INSERT INTO users (name, email, phone) VALUES ('***', '***', '***')", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := RedactSQL(tt.sql) + if result != tt.expected { + t.Errorf("RedactSQL:\nGot: %s\nExpected: %s", result, tt.expected) + } + }) + } +} + +// TestIsReadOnlyMode tests read-only mode detection +func TestIsReadOnlyMode(t *testing.T) { + tests := []struct { + name string + readOnlyMode bool + enableWriteMode bool + expected bool + }{ + {"Read-only by default", true, false, true}, + {"Write mode enabled", false, true, false}, + {"Both false", false, false, false}, + {"Read-only overridden by write mode", true, true, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + source := &Source{ + Config: Config{ + ReadOnlyMode: tt.readOnlyMode, + EnableWriteMode: tt.enableWriteMode, + }, + } + + result := source.IsReadOnlyMode() + if result != tt.expected { + t.Errorf("IsReadOnlyMode() = %v, want %v", result, tt.expected) + } + }) + } +} + +// TestStructuredError tests error formatting +func TestStructuredError(t *testing.T) { + err := &StructuredError{ + Code: ErrCodeReadOnlyViolation, + Message: "Write operations not allowed", + Details: map[string]any{ + "sql_type": "INSERT", + }, + } + + errorStr := err.Error() + if !strings.Contains(errorStr, ErrCodeReadOnlyViolation) { + t.Errorf("Error string should contain error code: %s", errorStr) + } + if !strings.Contains(errorStr, "Write operations not allowed") { + t.Errorf("Error string should contain message: %s", errorStr) + } +} + +// TestDefaultSecuritySettings tests that security defaults are correct +func TestDefaultSecuritySettings(t *testing.T) { + ctx := context.Background() + + // Create a minimal YAML config + yamlData := `name: test +type: cockroachdb +host: localhost +port: "26257" +user: root +database: defaultdb +` + + var cfg Config + if err := yaml.Unmarshal([]byte(yamlData), &cfg); err != nil { + t.Fatalf("Failed to unmarshal YAML: %v", err) + } + + // Apply defaults through newConfig logic manually + cfg.MaxRetries = 5 + cfg.RetryBaseDelay = "500ms" + cfg.ReadOnlyMode = true + cfg.EnableWriteMode = false + cfg.MaxRowLimit = 1000 + cfg.QueryTimeoutSec = 30 + cfg.EnableTelemetry = true + cfg.TelemetryVerbose = false + + _ = ctx // prevent unused + + // Verify MCP security defaults + if !cfg.ReadOnlyMode { + t.Error("ReadOnlyMode should be true by default") + } + if cfg.EnableWriteMode { + t.Error("EnableWriteMode should be false by default") + } + if cfg.MaxRowLimit != 1000 { + t.Errorf("MaxRowLimit should be 1000, got %d", cfg.MaxRowLimit) + } + if cfg.QueryTimeoutSec != 30 { + t.Errorf("QueryTimeoutSec should be 30, got %d", cfg.QueryTimeoutSec) + } + if !cfg.EnableTelemetry { + t.Error("EnableTelemetry should be true by default") + } +} diff --git a/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go b/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go new file mode 100644 index 0000000000..7bd4f07345 --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go @@ -0,0 +1,186 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdbexecutesql + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5" +) + +const kind string = "cockroachdb-execute-sql" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) +} + +var compatibleSources = [...]string{cockroachdb.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) ToolConfigType() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + _, ok = rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") + params := parameters.Parameters{sqlParameter} + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) + + t := Tool{ + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +var _ tools.Tool = Tool{} + +type Tool struct { + Config + Parameters parameters.Parameters `yaml:"parameters"` + + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + sql, ok := paramsMap["sql"].(string) + if !ok { + return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + } + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("error getting logger: %s", err) + } + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) + + results, err := source.Query(ctx, sql) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + + var out []any + for results.Next() { + v, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + row := orderedmap.Row{} + for i, f := range fields { + row.Add(f.Name, v[i]) + } + out = append(out, row) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return params, nil +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(_ tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.Parameters +} diff --git a/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql_test.go b/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql_test.go new file mode 100644 index 0000000000..a7c8726a90 --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql_test.go @@ -0,0 +1,81 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdbexecutesql_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbexecutesql" +) + +func TestParseFromYamlCockroachDBExecuteSQL(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tools + name: execute_sql_tool + type: cockroachdb-execute-sql + source: my-crdb-instance + description: Execute SQL on CockroachDB + `, + want: server.ToolConfigs{ + "execute_sql_tool": cockroachdbexecutesql.Config{ + Name: "execute_sql_tool", + Type: "cockroachdb-execute-sql", + Source: "my-crdb-instance", + Description: "Execute SQL on CockroachDB", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} + +func TestCockroachDBExecuteSQLToolConfigType(t *testing.T) { + cfg := cockroachdbexecutesql.Config{ + Name: "test-tool", + Type: "cockroachdb-execute-sql", + Source: "test-source", + Description: "test description", + } + + if cfg.ToolConfigType() != "cockroachdb-execute-sql" { + t.Errorf("expected ToolConfigType 'cockroachdb-execute-sql', got %q", cfg.ToolConfigType()) + } +} diff --git a/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go b/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go new file mode 100644 index 0000000000..2a5c2dbc8e --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go @@ -0,0 +1,187 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdblistschemas + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5" +) + +const kind string = "cockroachdb-list-schemas" + +const listSchemasStatement = ` + SELECT + catalog_name, + schema_name, + crdb_is_user_defined + FROM information_schema.schemata + WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'crdb_internal', 'pg_extension') + AND schema_name NOT LIKE 'pg_temp_%' + AND schema_name NOT LIKE 'pg_toast_temp_%' + ORDER BY schema_name; +` + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) +} + +var _ compatibleSource = &cockroachdb.Source{} + +var compatibleSources = [...]string{cockroachdb.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) ToolConfigType() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + _, ok = rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + allParameters := parameters.Parameters{} + paramManifest := allParameters.Manifest() + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + t := Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + + return t, nil +} + +var _ tools.Tool = Tool{} + +type Tool struct { + Config + AllParams parameters.Parameters `yaml:"allParams"` + + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, err + } + + results, err := source.Query(ctx, listSchemasStatement) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + var out []map[string]any + + for results.Next() { + values, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + rowMap := make(map[string]any) + for i, field := range fields { + rowMap[string(field.Name)] = values[i] + } + out = append(out, rowMap) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("error reading query results: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.AllParams, data, claims) +} + +func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return params, nil +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(_ tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.AllParams +} diff --git a/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas_test.go b/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas_test.go new file mode 100644 index 0000000000..260187641f --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas_test.go @@ -0,0 +1,81 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdblistschemas_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblistschemas" +) + +func TestParseFromYamlCockroachDBListSchemas(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tools + name: list_schemas_tool + type: cockroachdb-list-schemas + source: my-crdb-instance + description: List schemas in CockroachDB + `, + want: server.ToolConfigs{ + "list_schemas_tool": cockroachdblistschemas.Config{ + Name: "list_schemas_tool", + Type: "cockroachdb-list-schemas", + Source: "my-crdb-instance", + Description: "List schemas in CockroachDB", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} + +func TestCockroachDBListSchemasToolConfigType(t *testing.T) { + cfg := cockroachdblistschemas.Config{ + Name: "test-tool", + Type: "cockroachdb-list-schemas", + Source: "test-source", + Description: "test description", + } + + if cfg.ToolConfigType() != "cockroachdb-list-schemas" { + t.Errorf("expected ToolConfigType 'cockroachdb-list-schemas', got %q", cfg.ToolConfigType()) + } +} diff --git a/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go b/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go new file mode 100644 index 0000000000..254ee3b658 --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go @@ -0,0 +1,261 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdblisttables + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5" +) + +const kind string = "cockroachdb-list-tables" + +const listTablesStatement = ` + WITH desired_relkinds AS ( + SELECT ARRAY['r', 'p']::char[] AS kinds + ), + table_info AS ( + SELECT + t.oid AS table_oid, + ns.nspname AS schema_name, + t.relname AS table_name, + pg_get_userbyid(t.relowner) AS table_owner, + obj_description(t.oid, 'pg_class') AS table_comment, + t.relkind AS object_kind + FROM + pg_class t + JOIN + pg_namespace ns ON ns.oid = t.relnamespace + CROSS JOIN desired_relkinds dk + WHERE + t.relkind = ANY(dk.kinds) + AND (NULLIF(TRIM($1), '') IS NULL OR t.relname = ANY(string_to_array($1,','))) + AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'crdb_internal', 'pg_extension') + AND ns.nspname NOT LIKE 'pg_temp_%' AND ns.nspname NOT LIKE 'pg_toast_temp_%' + ), + columns_info AS ( + SELECT + att.attrelid AS table_oid, att.attname AS column_name, format_type(att.atttypid, att.atttypmod) AS data_type, + att.attnum AS column_ordinal_position, att.attnotnull AS is_not_nullable, + pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(att.attrelid, att.attnum) AS column_comment + FROM pg_attribute att LEFT JOIN pg_attrdef ad ON att.attrelid = ad.adrelid AND att.attnum = ad.adnum + JOIN table_info ti ON att.attrelid = ti.table_oid WHERE att.attnum > 0 AND NOT att.attisdropped + ), + constraints_info AS ( + SELECT + con.conrelid AS table_oid, con.conname AS constraint_name, pg_get_constraintdef(con.oid) AS constraint_definition, + CASE con.contype WHEN 'p' THEN 'PRIMARY KEY' WHEN 'f' THEN 'FOREIGN KEY' WHEN 'u' THEN 'UNIQUE' WHEN 'c' THEN 'CHECK' ELSE con.contype::text END AS constraint_type, + (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.conkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = u.attnum) AS constraint_columns, + NULLIF(con.confrelid, 0)::regclass AS foreign_key_referenced_table, + (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.confkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.confrelid AND att.attnum = u.attnum WHERE con.contype = 'f') AS foreign_key_referenced_columns + FROM pg_constraint con JOIN table_info ti ON con.conrelid = ti.table_oid + ), + indexes_info AS ( + SELECT + idx.indrelid AS table_oid, ic.relname AS index_name, pg_get_indexdef(idx.indexrelid) AS index_definition, + idx.indisunique AS is_unique, idx.indisprimary AS is_primary, am.amname AS index_method, + (SELECT array_agg(att.attname ORDER BY u.ord) FROM unnest(idx.indkey::int[]) WITH ORDINALITY AS u(colidx, ord) LEFT JOIN pg_attribute att ON att.attrelid = idx.indrelid AND att.attnum = u.colidx WHERE u.colidx <> 0) AS index_columns + FROM pg_index idx JOIN pg_class ic ON ic.oid = idx.indexrelid JOIN pg_am am ON am.oid = ic.relam JOIN table_info ti ON idx.indrelid = ti.table_oid + ) + SELECT + ti.schema_name, + ti.table_name AS object_name, + CASE + WHEN $2 = 'simple' THEN + json_build_object('name', ti.table_name) + ELSE + json_build_object( + 'schema_name', ti.schema_name, + 'object_name', ti.table_name, + 'object_type', CASE ti.object_kind + WHEN 'r' THEN 'TABLE' + WHEN 'p' THEN 'PARTITIONED TABLE' + ELSE ti.object_kind::text + END, + 'owner', ti.table_owner, + 'comment', ti.table_comment, + 'columns', COALESCE((SELECT json_agg(json_build_object('column_name',ci.column_name,'data_type',ci.data_type,'ordinal_position',ci.column_ordinal_position,'is_not_nullable',ci.is_not_nullable,'column_default',ci.column_default,'column_comment',ci.column_comment) ORDER BY ci.column_ordinal_position) FROM columns_info ci WHERE ci.table_oid = ti.table_oid), '[]'::json), + 'constraints', COALESCE((SELECT json_agg(json_build_object('constraint_name',cons.constraint_name,'constraint_type',cons.constraint_type,'constraint_definition',cons.constraint_definition,'constraint_columns',cons.constraint_columns,'foreign_key_referenced_table',cons.foreign_key_referenced_table,'foreign_key_referenced_columns',cons.foreign_key_referenced_columns)) FROM constraints_info cons WHERE cons.table_oid = ti.table_oid), '[]'::json), + 'indexes', COALESCE((SELECT json_agg(json_build_object('index_name',ii.index_name,'index_definition',ii.index_definition,'is_unique',ii.is_unique,'is_primary',ii.is_primary,'index_method',ii.index_method,'index_columns',ii.index_columns)) FROM indexes_info ii WHERE ii.table_oid = ti.table_oid), '[]'::json) + ) + END AS object_details + FROM table_info ti ORDER BY ti.schema_name, ti.table_name; +` + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) +} + +var _ compatibleSource = &cockroachdb.Source{} + +var compatibleSources = [...]string{cockroachdb.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) ToolConfigType() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + _, ok = rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + allParameters := parameters.Parameters{ + parameters.NewStringParameterWithDefault("table_names", "", "Optional: A comma-separated list of table names. If empty, details for all tables will be listed."), + parameters.NewStringParameterWithDefault("output_format", "detailed", "Optional: Use 'simple' for names only or 'detailed' for full info."), + } + paramManifest := allParameters.Manifest() + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + t := Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + + return t, nil +} + +var _ tools.Tool = Tool{} + +type Tool struct { + Config + AllParams parameters.Parameters `yaml:"allParams"` + + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + + tableNames, ok := paramsMap["table_names"].(string) + if !ok { + return nil, fmt.Errorf("invalid 'table_names' parameter; expected a string") + } + outputFormat, _ := paramsMap["output_format"].(string) + if outputFormat != "simple" && outputFormat != "detailed" { + return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) + } + + results, err := source.Query(ctx, listTablesStatement, tableNames, outputFormat) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + var out []map[string]any + + for results.Next() { + values, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + rowMap := make(map[string]any) + for i, field := range fields { + rowMap[string(field.Name)] = values[i] + } + out = append(out, rowMap) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("error reading query results: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.AllParams, data, claims) +} + +func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return params, nil +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(_ tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.AllParams +} diff --git a/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables_test.go b/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables_test.go new file mode 100644 index 0000000000..60516919b4 --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables_test.go @@ -0,0 +1,81 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdblisttables_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblisttables" +) + +func TestParseFromYamlCockroachDBListTables(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tools + name: list_tables_tool + type: cockroachdb-list-tables + source: my-crdb-instance + description: List tables in CockroachDB + `, + want: server.ToolConfigs{ + "list_tables_tool": cockroachdblisttables.Config{ + Name: "list_tables_tool", + Type: "cockroachdb-list-tables", + Source: "my-crdb-instance", + Description: "List tables in CockroachDB", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} + +func TestCockroachDBListTablesToolConfigType(t *testing.T) { + cfg := cockroachdblisttables.Config{ + Name: "test-tool", + Type: "cockroachdb-list-tables", + Source: "test-source", + Description: "test description", + } + + if cfg.ToolConfigType() != "cockroachdb-list-tables" { + t.Errorf("expected ToolConfigType 'cockroachdb-list-tables', got %q", cfg.ToolConfigType()) + } +} diff --git a/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go b/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go new file mode 100644 index 0000000000..33b1830545 --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go @@ -0,0 +1,192 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdbsql + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5" +) + +const kind string = "cockroachdb-sql" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) +} + +var _ compatibleSource = &cockroachdb.Source{} + +var compatibleSources = [...]string{cockroachdb.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Statement string `yaml:"statement" validate:"required"` + AuthRequired []string `yaml:"authRequired"` + Parameters parameters.Parameters `yaml:"parameters"` + TemplateParameters parameters.Parameters `yaml:"templateParameters"` +} + +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) ToolConfigType() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + _, ok = rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) + if err != nil { + return nil, err + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + t := Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +var _ tools.Tool = Tool{} + +type Tool struct { + Config + AllParams parameters.Parameters `yaml:"allParams"` + + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract template params %w", err) + } + + newParams, err := parameters.GetParams(t.Parameters, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + results, err := source.Query(ctx, newStatement, sliceParams...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + + var out []any + for results.Next() { + v, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + row := orderedmap.Row{} + for i, f := range fields { + row.Add(f.Name, v[i]) + } + out = append(out, row) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.AllParams, data, claims) +} + +func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return params, nil +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(_ tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.AllParams +} diff --git a/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql_test.go b/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql_test.go new file mode 100644 index 0000000000..1cca50a280 --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql_test.go @@ -0,0 +1,93 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdbsql_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbsql" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +func TestParseFromYamlCockroachDB(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tools + name: example_tool + type: cockroachdb-sql + source: my-crdb-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + parameters: + - name: user_id + type: string + description: user id parameter + `, + want: server.ToolConfigs{ + "example_tool": cockroachdbsql.Config{ + Name: "example_tool", + Type: "cockroachdb-sql", + Source: "my-crdb-instance", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + AuthRequired: []string{}, + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("user_id", "user id parameter"), + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} + +func TestCockroachDBSQLToolConfigType(t *testing.T) { + cfg := cockroachdbsql.Config{ + Name: "test-tool", + Type: "cockroachdb-sql", + Source: "test-source", + Description: "test description", + Statement: "SELECT 1", + } + + if cfg.ToolConfigType() != "cockroachdb-sql" { + t.Errorf("expected ToolConfigType 'cockroachdb-sql', got %q", cfg.ToolConfigType()) + } +} diff --git a/tests/cockroachdb/cockroachdb_integration_test.go b/tests/cockroachdb/cockroachdb_integration_test.go new file mode 100644 index 0000000000..43abb36207 --- /dev/null +++ b/tests/cockroachdb/cockroachdb_integration_test.go @@ -0,0 +1,220 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdb + +import ( + "context" + "fmt" + "net/url" + "os" + "regexp" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/tests" + "github.com/jackc/pgx/v5/pgxpool" +) + +var ( + CockroachDBSourceKind = "cockroachdb" + CockroachDBToolKind = "cockroachdb-sql" + CockroachDBDatabase = getEnvOrDefault("COCKROACHDB_DATABASE", "defaultdb") + CockroachDBHost = getEnvOrDefault("COCKROACHDB_HOST", "localhost") + CockroachDBPort = getEnvOrDefault("COCKROACHDB_PORT", "26257") + CockroachDBUser = getEnvOrDefault("COCKROACHDB_USER", "root") + CockroachDBPass = getEnvOrDefault("COCKROACHDB_PASS", "") +) + +func getEnvOrDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +func getCockroachDBVars(t *testing.T) map[string]any { + if CockroachDBHost == "" { + t.Skip("COCKROACHDB_HOST not set, skipping CockroachDB integration test") + } + + return map[string]any{ + "type": CockroachDBSourceKind, + "host": CockroachDBHost, + "port": CockroachDBPort, + "database": CockroachDBDatabase, + "user": CockroachDBUser, + "password": CockroachDBPass, + "maxRetries": 5, + "retryBaseDelay": "500ms", + "queryParams": map[string]string{ + "sslmode": "disable", + }, + } +} + +func initCockroachDBConnectionPool(host, port, user, pass, dbname string) (*pgxpool.Pool, error) { + connURL := &url.URL{ + Scheme: "postgres", + User: url.UserPassword(user, pass), + Host: fmt.Sprintf("%s:%s", host, port), + Path: dbname, + RawQuery: "sslmode=disable&application_name=cockroachdb-integration-test", + } + pool, err := pgxpool.New(context.Background(), connURL.String()) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + return pool, nil +} + +func TestCockroachDB(t *testing.T) { + sourceConfig := getCockroachDBVars(t) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + var args []string + + pool, err := initCockroachDBConnectionPool(CockroachDBHost, CockroachDBPort, CockroachDBUser, CockroachDBPass, CockroachDBDatabase) + if err != nil { + t.Fatalf("unable to create cockroachdb connection pool: %s", err) + } + // Note: Don't defer pool.Close() here - the pool is only used for test setup/teardown. + // Closing it explicitly can cause hangs if the server's pool is still active. + // The pool will be cleaned up when the test exits. + + // Verify CockroachDB version + var version string + err = pool.QueryRow(ctx, "SELECT version()").Scan(&version) + if err != nil { + t.Fatalf("failed to query version: %s", err) + } + if !strings.Contains(version, "CockroachDB") { + t.Fatalf("not connected to CockroachDB, got: %s", version) + } + t.Logf("✅ Connected to: %s", version) + + // cleanup test environment + tests.CleanupPostgresTables(t, ctx, pool) + + // create table names with UUID suffix + tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + + // set up data for param tool (using CockroachDB explicit INT primary keys) + createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetCockroachDBParamToolInfo(tableNameParam) + teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) + defer teardownTable1(t) + + // set up data for auth tool + createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetCockroachDBAuthToolInfo(tableNameAuth) + teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) + defer teardownTable2(t) + + // Write config into a file and pass it to command + toolsFile := tests.GetToolsConfig(sourceConfig, CockroachDBToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) + + // Add execute-sql tool with write-enabled source (CockroachDB MCP security requires explicit opt-in) + toolsFile = addCockroachDBExecuteSqlConfig(t, toolsFile, sourceConfig) + + tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement() + toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CockroachDBToolKind, tmplSelectCombined, tmplSelectFilterCombined, "") + + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + // Get configs for tests (use CockroachDB-specific expectations) + select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want := tests.GetCockroachDBWants() + + // Run required integration test suites (per CONTRIBUTING.md) + t.Run("ToolGetTest", func(t *testing.T) { + tests.RunToolGetTest(t) + }) + + t.Run("ToolInvokeTest", func(t *testing.T) { + tests.RunToolInvokeTest(t, select1Want) + }) + + t.Run("MCPToolCallMethod", func(t *testing.T) { + tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want) + }) + + t.Run("ExecuteSqlToolInvokeTest", func(t *testing.T) { + tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want) + }) + + t.Run("ToolInvokeWithTemplateParameters", func(t *testing.T) { + tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam) + }) + + t.Logf("✅✅✅ All CockroachDB integration tests passed!") +} + +// addCockroachDBExecuteSqlConfig adds execute-sql tool with write-enabled source +// CockroachDB has MCP security enabled by default, so execute-sql needs a separate source with enableWriteMode +func addCockroachDBExecuteSqlConfig(t *testing.T, config map[string]any, baseSourceConfig map[string]any) map[string]any { + // Add write-enabled source for execute-sql tool + sources, ok := config["sources"].(map[string]any) + if !ok { + t.Fatalf("unable to get sources from config") + } + + // Create a copy of the base source config with write mode enabled + writeEnabledSource := make(map[string]any) + for k, v := range baseSourceConfig { + writeEnabledSource[k] = v + } + writeEnabledSource["enableWriteMode"] = true + writeEnabledSource["readOnlyMode"] = false + + sources["my-write-instance"] = writeEnabledSource + + // Add tools using the write-enabled source + tools, ok := config["tools"].(map[string]any) + if !ok { + t.Fatalf("unable to get tools from config") + } + + tools["my-exec-sql-tool"] = map[string]any{ + "type": "cockroachdb-execute-sql", + "source": "my-write-instance", + "description": "Tool to execute sql", + } + tools["my-auth-exec-sql-tool"] = map[string]any{ + "type": "cockroachdb-execute-sql", + "source": "my-write-instance", + "description": "Tool to execute sql", + "authRequired": []string{ + "my-google-auth", + }, + } + + return config +} diff --git a/tests/common.go b/tests/common.go index fb56159e51..d200d59dd6 100644 --- a/tests/common.go +++ b/tests/common.go @@ -528,6 +528,41 @@ func GetPostgresSQLTmplToolStatement() (string, string) { return tmplSelectCombined, tmplSelectFilterCombined } +// GetCockroachDBParamToolInfo returns statements and param for my-tool cockroachdb-sql type +// Uses explicit INT PRIMARY KEY instead of SERIAL to ensure deterministic IDs +func GetCockroachDBParamToolInfo(tableName string) (string, string, string, string, string, string, []any) { + createStatement := fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name TEXT);", tableName) + insertStatement := fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, $1), (2, $2), (3, $3), (4, $4);", tableName) + toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = $1 OR name = $2 ORDER BY id;", tableName) + idParamStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = $1;", tableName) + nameParamStatement := fmt.Sprintf("SELECT * FROM %s WHERE name = $1;", tableName) + arrayToolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ANY($1) AND name = ANY($2) ORDER BY id;", tableName) + params := []any{"Alice", "Jane", "Sid", nil} + return createStatement, insertStatement, toolStatement, idParamStatement, nameParamStatement, arrayToolStatement, params +} + +// GetCockroachDBAuthToolInfo returns statements and param of my-auth-tool for cockroachdb-sql type +// Uses explicit INT PRIMARY KEY instead of SERIAL to ensure deterministic IDs +func GetCockroachDBAuthToolInfo(tableName string) (string, string, string, []any) { + createStatement := fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name TEXT, email TEXT);", tableName) + insertStatement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (1, $1, $2), (2, $3, $4)", tableName) + toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = $1;", tableName) + params := []any{"Alice", ServiceAccountEmail, "Jane", "janedoe@gmail.com"} + return createStatement, insertStatement, toolStatement, params +} + +// GetCockroachDBWants return the expected wants for cockroachdb +func GetCockroachDBWants() (string, string, string, string) { + select1Want := "[{\"?column?\":1}]" + // CockroachDB formats syntax errors differently than PostgreSQL: + // - Uses lowercase for SQL keywords in error messages + // - Uses format: 'at or near "token": syntax error' instead of 'syntax error at or near "TOKEN"' + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ERROR: at or near \"selec\": syntax error (SQLSTATE 42601)"}],"isError":true}}` + createTableStatement := `"CREATE TABLE t (id INT PRIMARY KEY, name TEXT)"` + mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"?column?\":1}"}]}}` + return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want +} + // GetMSSQLParamToolInfo returns statements and param for my-tool mssql-sql type func GetMSSQLParamToolInfo(tableName string) (string, string, string, string, string, string, []any) { createStatement := fmt.Sprintf("CREATE TABLE %s (id INT IDENTITY(1,1) PRIMARY KEY, name VARCHAR(255));", tableName) From 32cb4db712d27579c1bf29e61cbd0bed02286c28 Mon Sep 17 00:00:00 2001 From: Wenxin Du <117315983+duwenxin99@users.noreply.github.com> Date: Thu, 12 Feb 2026 08:34:28 -0500 Subject: [PATCH 2/9] feat(server): add Tool call error categories (#2387) Create Agent vs Server error types to distinguish between the two types. --------- Co-authored-by: Averi Kitsch --- internal/util/errors.go | 77 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 internal/util/errors.go diff --git a/internal/util/errors.go b/internal/util/errors.go new file mode 100644 index 0000000000..38dd7f5954 --- /dev/null +++ b/internal/util/errors.go @@ -0,0 +1,77 @@ +// Copyright 2026 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package util + +import "fmt" + +type ErrorCategory string + +const ( + CategoryAgent ErrorCategory = "AGENT_ERROR" + CategoryServer ErrorCategory = "SERVER_ERROR" +) + +// ToolboxError is the interface all custom errors must satisfy +type ToolboxError interface { + error + Category() ErrorCategory + Error() string + Unwrap() error +} + +// Agent Errors return 200 to the sender +type AgentError struct { + Msg string + Cause error +} + +var _ ToolboxError = &AgentError{} + +func (e *AgentError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("%s: %v", e.Msg, e.Cause) + } + return e.Msg +} + +func (e *AgentError) Category() ErrorCategory { return CategoryAgent } + +func (e *AgentError) Unwrap() error { return e.Cause } + +func NewAgentError(msg string, cause error) *AgentError { + return &AgentError{Msg: msg, Cause: cause} +} + +// ClientServerError returns 4XX/5XX error code +type ClientServerError struct { + Msg string + Code int + Cause error +} + +var _ ToolboxError = &ClientServerError{} + +func (e *ClientServerError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("%s: %v", e.Msg, e.Cause) + } + return e.Msg +} + +func (e *ClientServerError) Category() ErrorCategory { return CategoryServer } + +func (e *ClientServerError) Unwrap() error { return e.Cause } + +func NewClientServerError(msg string, code int, cause error) *ClientServerError { + return &ClientServerError{Msg: msg, Code: code, Cause: cause} +} From 32610d71a3838c9af5a379f97b92914e3dc6b0ba Mon Sep 17 00:00:00 2001 From: Wenxin Du <117315983+duwenxin99@users.noreply.github.com> Date: Thu, 12 Feb 2026 11:42:27 -0500 Subject: [PATCH 3/9] refactor(server): standardize tool error handling and status code mapping (#2402) - Detect errors and return error codes accordingly in the tool call handler functions. - Replace the old `util.ErrUnauthorized` with the new Toolbox error type. --- internal/server/api.go | 86 ++- internal/server/mcp.go | 12 +- internal/server/mcp/v20241105/method.go | 70 ++- internal/server/mcp/v20250326/method.go | 70 ++- internal/server/mcp/v20250618/method.go | 69 ++- internal/server/mcp/v20251125/method.go | 69 ++- internal/server/mcp_test.go | 6 +- internal/server/mocks.go | 3 +- internal/sources/alloydbadmin/alloydbadmin.go | 6 +- .../alloydbcreatecluster.go | 25 +- .../alloydbcreateinstance.go | 24 +- .../alloydbcreateuser/alloydbcreateuser.go | 24 +- .../alloydbgetcluster/alloydbgetcluster.go | 24 +- .../alloydbgetinstance/alloydbgetinstance.go | 28 +- .../alloydb/alloydbgetuser/alloydbgetuser.go | 28 +- .../alloydblistclusters.go | 18 +- .../alloydblistinstances.go | 20 +- .../alloydblistusers/alloydblistusers.go | 24 +- .../alloydbwaitforoperation.go | 20 +- internal/tools/alloydbainl/alloydbainl.go | 8 +- .../bigqueryanalyzecontribution.go | 45 +- .../bigqueryconversationalanalytics.go | 21 +- .../bigqueryexecutesql/bigqueryexecutesql.go | 41 +- .../bigqueryforecast/bigqueryforecast.go | 51 +- .../bigquerygetdatasetinfo.go | 18 +- .../bigquerygettableinfo.go | 18 +- .../bigquerylistdatasetids.go | 12 +- .../bigquerylisttableids.go | 17 +- .../bigquerysearchcatalog.go | 23 +- .../tools/bigquery/bigquerysql/bigquerysql.go | 30 +- internal/tools/bigtable/bigtable.go | 18 +- .../cassandra/cassandracql/cassandracql.go | 18 +- .../clickhouseexecutesql.go | 14 +- .../clickhouselistdatabases.go | 8 +- .../clickhouselisttables.go | 17 +- .../clickhouse/clickhousesql/clickhousesql.go | 16 +- internal/tools/cloudgda/cloudgda.go | 20 +- .../cloudhealthcarefhirfetchpage.go | 19 +- .../cloudhealthcarefhirpatienteverything.go | 26 +- .../cloudhealthcarefhirpatientsearch.go | 24 +- .../cloudhealthcaregetdataset.go | 14 +- .../cloudhealthcaregetdicomstore.go | 16 +- .../cloudhealthcaregetdicomstoremetrics.go | 16 +- .../cloudhealthcaregetfhirresource.go | 20 +- .../cloudhealthcaregetfhirstore.go | 16 +- .../cloudhealthcaregetfhirstoremetrics.go | 16 +- .../cloudhealthcarelistdicomstores.go | 14 +- .../cloudhealthcarelistfhirstores.go | 14 +- ...healthcareretrieverendereddicominstance.go | 24 +- .../cloudhealthcaresearchdicominstances.go | 22 +- .../cloudhealthcaresearchdicomseries.go | 20 +- .../cloudhealthcaresearchdicomstudies.go | 18 +- .../cloudloggingadminlistlognames.go | 16 +- .../cloudloggingadminlistresourcetypes.go | 14 +- .../cloudloggingadminquerylogs.go | 22 +- .../tools/cloudmonitoring/cloudmonitoring.go | 15 +- .../cloudsqlcloneinstance.go | 18 +- .../cloudsqlcreatebackup.go | 16 +- .../cloudsqlcreatedatabase.go | 18 +- .../cloudsqlcreateusers.go | 22 +- .../cloudsqlgetinstances.go | 16 +- .../cloudsqllistdatabases.go | 16 +- .../cloudsqllistinstances.go | 14 +- .../cloudsqlrestorebackup.go | 18 +- .../cloudsqlwaitforoperation.go | 18 +- .../cloudsqlmssqlcreateinstance.go | 24 +- .../cloudsqlmysqlcreateinstance.go | 24 +- .../cloudsqlpgcreateinstances.go | 24 +- .../cloudsqlpgupgradeprecheck.go | 22 +- .../cockroachdbexecutesql.go | 22 +- .../cockroachdblistschemas.go | 16 +- .../cockroachdblisttables.go | 20 +- .../cockroachdbsql/cockroachdbsql.go | 20 +- internal/tools/couchbase/couchbase.go | 20 +- .../dataformcompilelocal.go | 8 +- .../dataplexlookupentry.go | 14 +- .../dataplexsearchaspecttypes.go | 27 +- .../dataplexsearchentries.go | 27 +- internal/tools/dgraph/dgraph.go | 12 +- .../elasticsearchesql/elasticsearchesql.go | 14 +- .../firebirdexecutesql/firebirdexecutesql.go | 20 +- .../tools/firebird/firebirdsql/firebirdsql.go | 17 +- .../firestoreadddocuments.go | 26 +- .../firestoredeletedocuments.go | 26 +- .../firestoregetdocuments.go | 26 +- .../firestoregetrules/firestoregetrules.go | 12 +- .../firestorelistcollections.go | 18 +- .../firestorequery/firestorequery.go | 32 +- .../firestorequerycollection.go | 20 +- .../firestoreupdatedocument.go | 36 +- .../firestorevalidaterules.go | 14 +- internal/tools/http/http.go | 28 +- .../lookeradddashboardelement.go | 45 +- .../lookeradddashboardfilter.go | 57 +- .../lookerconversationalanalytics.go | 10 +- .../lookercreateprojectfile.go | 16 +- .../lookerdeleteprojectfile.go | 14 +- .../looker/lookerdevmode/lookerdevmode.go | 16 +- .../lookergenerateembedurl.go | 13 +- .../lookergetconnectiondatabases.go | 13 +- .../lookergetconnections.go | 13 +- .../lookergetconnectionschemas.go | 12 +- .../lookergetconnectiontablecolumns.go | 19 +- .../lookergetconnectiontables.go | 15 +- .../lookergetdashboards.go | 11 +- .../lookergetdimensions.go | 17 +- .../lookergetexplores/lookergetexplores.go | 13 +- .../lookergetfilters/lookergetfilters.go | 17 +- .../looker/lookergetlooks/lookergetlooks.go | 21 +- .../lookergetmeasures/lookergetmeasures.go | 17 +- .../looker/lookergetmodels/lookergetmodels.go | 11 +- .../lookergetparameters.go | 19 +- .../lookergetprojectfile.go | 15 +- .../lookergetprojectfiles.go | 13 +- .../lookergetprojects/lookergetprojects.go | 11 +- .../lookerhealthanalyze.go | 51 +- .../lookerhealthpulse/lookerhealthpulse.go | 13 +- .../lookerhealthvacuum/lookerhealthvacuum.go | 23 +- .../lookermakedashboard.go | 19 +- .../looker/lookermakelook/lookermakelook.go | 23 +- .../tools/looker/lookerquery/lookerquery.go | 15 +- .../looker/lookerquerysql/lookerquerysql.go | 13 +- .../looker/lookerqueryurl/lookerqueryurl.go | 18 +- .../lookerrundashboard/lookerrundashboard.go | 11 +- .../looker/lookerrunlook/lookerrunlook.go | 15 +- .../lookerupdateprojectfile.go | 16 +- .../lookervalidateproject.go | 13 +- .../mindsdbexecutesql/mindsdbexecutesql.go | 19 +- .../tools/mindsdb/mindsdbsql/mindsdbsql.go | 19 +- .../mongodbaggregate/mongodbaggregate.go | 14 +- .../mongodbdeletemany/mongodbdeletemany.go | 14 +- .../mongodbdeleteone/mongodbdeleteone.go | 16 +- .../tools/mongodb/mongodbfind/mongodbfind.go | 17 +- .../mongodb/mongodbfindone/mongodbfindone.go | 18 +- .../mongodbinsertmany/mongodbinsertmany.go | 17 +- .../mongodbinsertone/mongodbinsertone.go | 17 +- .../mongodbupdatemany/mongodbupdatemany.go | 16 +- .../mongodbupdateone/mongodbupdateone.go | 16 +- .../mssql/mssqlexecutesql/mssqlexecutesql.go | 19 +- .../mssql/mssqllisttables/mssqllisttables.go | 504 +++++++++--------- internal/tools/mssql/mssqlsql/mssqlsql.go | 16 +- .../mysql/mysqlexecutesql/mysqlexecutesql.go | 19 +- .../mysqlgetqueryplan/mysqlgetqueryplan.go | 25 +- .../mysqllistactivequeries.go | 129 ++--- .../mysqllisttablefragmentation.go | 59 +- .../mysql/mysqllisttables/mysqllisttables.go | 14 +- .../mysqllisttablesmissinguniqueindexes.go | 57 +- internal/tools/mysql/mysqlsql/mysqlsql.go | 16 +- .../tools/neo4j/neo4jcypher/neo4jcypher.go | 12 +- .../neo4jexecutecypher/neo4jexecutecypher.go | 18 +- .../tools/neo4j/neo4jschema/neo4jschema.go | 32 +- .../oceanbaseexecutesql.go | 14 +- .../oceanbase/oceanbasesql/oceanbasesql.go | 16 +- .../oracleexecutesql/oracleexecutesql.go | 15 +- internal/tools/oracle/oraclesql/oraclesql.go | 16 +- .../postgresdatabaseoverview.go | 25 +- .../postgresexecutesql/postgresexecutesql.go | 19 +- .../postgresgetcolumncardinality.go | 19 +- .../postgreslistactivequeries.go | 62 +-- .../postgreslistavailableextensions.go | 29 +- .../postgreslistdatabasestats.go | 132 ++--- .../postgreslistindexes.go | 103 ++-- .../postgreslistinstalledextensions.go | 52 +- .../postgreslistlocks/postgreslistlocks.go | 19 +- .../postgreslistpgsettings.go | 27 +- .../postgreslistpublicationtables.go | 71 +-- .../postgreslistquerystats.go | 19 +- .../postgreslistroles/postgreslistroles.go | 95 ++-- .../postgreslistschemas.go | 53 +- .../postgreslistsequences.go | 49 +- .../postgresliststoredprocedure.go | 15 +- .../postgreslisttables/postgreslisttables.go | 166 +++--- .../postgreslisttablespaces.go | 65 +-- .../postgreslisttablestats.go | 35 +- .../postgreslisttriggers.go | 103 ++-- .../postgreslistviews/postgreslistviews.go | 53 +- .../postgreslongrunningtransactions.go | 19 +- .../postgresreplicationstats.go | 25 +- .../tools/postgres/postgressql/postgressql.go | 19 +- internal/tools/redis/redis.go | 14 +- .../tools/serverlessspark/createbatch/tool.go | 24 +- .../serverlesssparkcancelbatch.go | 18 +- .../serverlesssparkgetbatch.go | 17 +- .../serverlesssparklistbatches.go | 28 +- .../singlestoreexecutesql.go | 20 +- .../singlestoresql/singlestoresql.go | 16 +- .../snowflakeexecutesql.go | 19 +- .../snowflake/snowflakesql/snowflakesql.go | 16 +- .../spannerexecutesql/spannerexecutesql.go | 15 +- .../spannerlistgraphs/spannerlistgraphs.go | 14 +- .../spannerlisttables/spannerlisttables.go | 24 +- .../tools/spanner/spannersql/spannersql.go | 31 +- .../sqliteexecutesql/sqliteexecutesql.go | 24 +- internal/tools/sqlite/sqlitesql/sqlitesql.go | 16 +- .../tidb/tidbexecutesql/tidbexecutesql.go | 20 +- internal/tools/tidb/tidbsql/tidbsql.go | 16 +- internal/tools/tools.go | 5 +- .../trino/trinoexecutesql/trinoexecutesql.go | 14 +- internal/tools/trino/trinosql/trinosql.go | 16 +- internal/tools/utility/wait/wait.go | 7 +- internal/tools/valkey/valkey.go | 14 +- internal/tools/yugabytedbsql/yugabytedbsql.go | 16 +- internal/util/errors.go | 65 ++- internal/util/parameters/parameters.go | 9 +- internal/util/util.go | 3 - tests/alloydb/alloydb_integration_test.go | 181 +++---- .../alloydb_wait_for_operation_test.go | 63 +-- tests/bigquery/bigquery_integration_test.go | 80 +-- tests/bigtable/bigtable_integration_test.go | 2 +- tests/cassandra/cassandra_integration_test.go | 2 +- .../clickhouse/clickhouse_integration_test.go | 26 +- .../cloud_healthcare_integration_test.go | 61 ++- .../cloud_logging_admin_integration_test.go | 4 +- .../cloudsql/cloud_sql_clone_instance_test.go | 9 +- .../cloudsql/cloud_sql_create_backup_test.go | 9 +- .../cloud_sql_create_database_test.go | 9 +- tests/cloudsql/cloud_sql_create_users_test.go | 9 +- .../cloudsql/cloud_sql_list_databases_test.go | 27 +- .../cloudsql/cloud_sql_restore_backup_test.go | 49 +- .../cloudsql_wait_for_operation_test.go | 8 +- ..._mssql_create_instance_integration_test.go | 9 +- ..._mysql_create_instance_integration_test.go | 9 +- .../cloud_sql_pg_create_instances_test.go | 9 +- .../cloud_sql_pg_upgrade_precheck_test.go | 25 +- tests/common.go | 8 +- tests/couchbase/couchbase_integration_test.go | 2 +- tests/dataform/dataform_integration_test.go | 6 +- tests/dataplex/dataplex_integration_test.go | 49 +- tests/firebird/firebird_integration_test.go | 2 +- tests/http/http_integration_test.go | 79 ++- tests/mariadb/mariadb_integration_test.go | 2 +- tests/mongodb/mongodb_integration_test.go | 5 +- tests/neo4j/neo4j_integration_test.go | 42 +- tests/oceanbase/oceanbase_integration_test.go | 2 +- tests/oracle/oracle_integration_test.go | 2 +- .../serverless_spark_integration_test.go | 56 +- .../singlestore_integration_test.go | 2 +- tests/snowflake/snowflake_integration_test.go | 2 +- tests/sqlite/sqlite_integration_test.go | 6 +- tests/tidb/tidb_integration_test.go | 2 +- tests/tool.go | 74 ++- tests/trino/trino_integration_test.go | 2 +- 242 files changed, 3863 insertions(+), 2720 deletions(-) diff --git a/internal/server/api.go b/internal/server/api.go index b5de3ec0a5..c992051269 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -19,7 +19,6 @@ import ( "errors" "fmt" "net/http" - "strings" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" @@ -216,7 +215,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { // Check if any of the specified auth services is verified isAuthorized := tool.Authorized(verifiedAuthServices) if !isAuthorized { - err = fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers") + err = fmt.Errorf("tool invocation not authorized. Please make sure you specify correct auth headers") s.logger.DebugContext(ctx, err.Error()) _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) return @@ -234,15 +233,28 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth) if err != nil { - // If auth error, return 401 - if errors.Is(err, util.ErrUnauthorized) { - s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err)) + var clientServerErr *util.ClientServerError + + // Return 401 Authentication errors + if errors.As(err, &clientServerErr) && clientServerErr.Code == http.StatusUnauthorized { + s.logger.DebugContext(ctx, fmt.Sprintf("auth error: %v", err)) _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) return } - err = fmt.Errorf("provided parameters were invalid: %w", err) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) + + var agentErr *util.AgentError + if errors.As(err, &agentErr) { + s.logger.DebugContext(ctx, fmt.Sprintf("agent validation error: %v", err)) + errMap := map[string]string{"error": err.Error()} + errMarshal, _ := json.Marshal(errMap) + + _ = render.Render(w, r, &resultResponse{Result: string(errMarshal)}) + return + } + + // Return 500 if it's a specific ClientServerError that isn't a 401, or any other unexpected error + s.logger.ErrorContext(ctx, fmt.Sprintf("internal server error: %v", err)) + _ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError)) return } s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) @@ -259,34 +271,50 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { // Determine what error to return to the users. if err != nil { - errStr := err.Error() - var statusCode int + var tbErr util.ToolboxError - // Upstream API auth error propagation - switch { - case strings.Contains(errStr, "Error 401"): - statusCode = http.StatusUnauthorized - case strings.Contains(errStr, "Error 403"): - statusCode = http.StatusForbidden - } + if errors.As(err, &tbErr) { + switch tbErr.Category() { + case util.CategoryAgent: + // Agent Errors -> 200 OK + s.logger.DebugContext(ctx, fmt.Sprintf("Tool invocation agent error: %v", err)) + res = map[string]string{ + "error": err.Error(), + } - if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { - if clientAuth { - // Propagate the original 401/403 error. - s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err)) + case util.CategoryServer: + // Server Errors -> Check the specific code inside + var clientServerErr *util.ClientServerError + statusCode := http.StatusInternalServerError // Default to 500 + + if errors.As(err, &clientServerErr) { + if clientServerErr.Code != 0 { + statusCode = clientServerErr.Code + } + } + + // Process auth error + if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { + if clientAuth { + // Token error, pass through 401/403 + s.logger.DebugContext(ctx, fmt.Sprintf("Client credentials lack authorization: %v", err)) + _ = render.Render(w, r, newErrResponse(err, statusCode)) + return + } + // ADC/Config error, return 500 + statusCode = http.StatusInternalServerError + } + + s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation server error: %v", err)) _ = render.Render(w, r, newErrResponse(err, statusCode)) return } - // ADC lacking permission or credentials configuration error. - internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err) - s.logger.ErrorContext(ctx, internalErr.Error()) - _ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError)) + } else { + // Unknown error -> 500 + s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation unknown error: %v", err)) + _ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError)) return } - err = fmt.Errorf("error while invoking tool: %w", err) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) - return } resMarshal, err := json.Marshal(res) diff --git a/internal/server/mcp.go b/internal/server/mcp.go index aecd2454f2..3adac31ab7 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -23,7 +23,6 @@ import ( "fmt" "io" "net/http" - "strings" "sync" "time" @@ -444,15 +443,12 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { code := rpcResponse.Error.Code switch code { case jsonrpc.INTERNAL_ERROR: + // Map Internal RPC Error (-32603) to HTTP 500 w.WriteHeader(http.StatusInternalServerError) case jsonrpc.INVALID_REQUEST: - errStr := err.Error() - if errors.Is(err, util.ErrUnauthorized) { - w.WriteHeader(http.StatusUnauthorized) - } else if strings.Contains(errStr, "Error 401") { - w.WriteHeader(http.StatusUnauthorized) - } else if strings.Contains(errStr, "Error 403") { - w.WriteHeader(http.StatusForbidden) + var clientServerErr *util.ClientServerError + if errors.As(err, &clientServerErr) { + w.WriteHeader(clientServerErr.Code) } } } diff --git a/internal/server/mcp/v20241105/method.go b/internal/server/mcp/v20241105/method.go index afcdd504ea..0dd6943734 100644 --- a/internal/server/mcp/v20241105/method.go +++ b/internal/server/mcp/v20241105/method.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" "net/http" - "strings" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" @@ -124,7 +123,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } if clientAuth { if accessToken == "" { - return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized + err := util.NewClientServerError( + "missing access token in the 'Authorization' header", + http.StatusUnauthorized, + nil, + ) + return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } } @@ -172,7 +176,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re // Check if any of the specified auth services is verified isAuthorized := tool.Authorized(verifiedAuthServices) if !isAuthorized { - err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized) + err = util.NewClientServerError( + "unauthorized Tool call: Please make sure you specify correct auth headers", + http.StatusUnauthorized, + nil, + ) return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } logger.DebugContext(ctx, "tool invocation authorized") @@ -194,30 +202,44 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re // run tool invocation and generate response. results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) if err != nil { - errStr := err.Error() - // Missing authService tokens. - if errors.Is(err, util.ErrUnauthorized) { - return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err - } - // Upstream auth error - if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") { - if clientAuth { - // Error with client credentials should pass down to the client - return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err + var tbErr util.ToolboxError + + if errors.As(err, &tbErr) { + switch tbErr.Category() { + case util.CategoryAgent: + // MCP - Tool execution error + // Return SUCCESS but with IsError: true + text := TextContent{ + Type: "text", + Text: err.Error(), + } + return jsonrpc.JSONRPCResponse{ + Jsonrpc: jsonrpc.JSONRPC_VERSION, + Id: id, + Result: CallToolResult{Content: []TextContent{text}, IsError: true}, + }, nil + + case util.CategoryServer: + // MCP Spec - Protocol error + // Return JSON-RPC ERROR + var clientServerErr *util.ClientServerError + rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603) + + if errors.As(err, &clientServerErr) { + if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden { + if clientAuth { + rpcCode = jsonrpc.INVALID_REQUEST + } else { + rpcCode = jsonrpc.INTERNAL_ERROR + } + } + } + return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err } - // Auth error with ADC should raise internal 500 error + } else { + // Unknown error -> 500 return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err } - - text := TextContent{ - Type: "text", - Text: err.Error(), - } - return jsonrpc.JSONRPCResponse{ - Jsonrpc: jsonrpc.JSONRPC_VERSION, - Id: id, - Result: CallToolResult{Content: []TextContent{text}, IsError: true}, - }, nil } content := make([]TextContent, 0) diff --git a/internal/server/mcp/v20250326/method.go b/internal/server/mcp/v20250326/method.go index 15798a2c07..22183d45d9 100644 --- a/internal/server/mcp/v20250326/method.go +++ b/internal/server/mcp/v20250326/method.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" "net/http" - "strings" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" @@ -124,7 +123,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } if clientAuth { if accessToken == "" { - return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized + err := util.NewClientServerError( + "missing access token in the 'Authorization' header", + http.StatusUnauthorized, + nil, + ) + return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } } @@ -172,7 +176,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re // Check if any of the specified auth services is verified isAuthorized := tool.Authorized(verifiedAuthServices) if !isAuthorized { - err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized) + err = util.NewClientServerError( + "unauthorized Tool call: Please make sure you specify correct auth headers", + http.StatusUnauthorized, + nil, + ) return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } logger.DebugContext(ctx, "tool invocation authorized") @@ -194,31 +202,45 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re // run tool invocation and generate response. results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) if err != nil { - errStr := err.Error() - // Missing authService tokens. - if errors.Is(err, util.ErrUnauthorized) { - return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err - } - // Upstream auth error - if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") { - if clientAuth { - // Error with client credentials should pass down to the client - return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err + var tbErr util.ToolboxError + + if errors.As(err, &tbErr) { + switch tbErr.Category() { + case util.CategoryAgent: + // MCP - Tool execution error + // Return SUCCESS but with IsError: true + text := TextContent{ + Type: "text", + Text: err.Error(), + } + return jsonrpc.JSONRPCResponse{ + Jsonrpc: jsonrpc.JSONRPC_VERSION, + Id: id, + Result: CallToolResult{Content: []TextContent{text}, IsError: true}, + }, nil + + case util.CategoryServer: + // MCP Spec - Protocol error + // Return JSON-RPC ERROR + var clientServerErr *util.ClientServerError + rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603) + + if errors.As(err, &clientServerErr) { + if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden { + if clientAuth { + rpcCode = jsonrpc.INVALID_REQUEST + } else { + rpcCode = jsonrpc.INTERNAL_ERROR + } + } + } + return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err } - // Auth error with ADC should raise internal 500 error + } else { + // Unknown error -> 500 return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err } - text := TextContent{ - Type: "text", - Text: err.Error(), - } - return jsonrpc.JSONRPCResponse{ - Jsonrpc: jsonrpc.JSONRPC_VERSION, - Id: id, - Result: CallToolResult{Content: []TextContent{text}, IsError: true}, - }, nil } - content := make([]TextContent, 0) sliceRes, ok := results.([]any) diff --git a/internal/server/mcp/v20250618/method.go b/internal/server/mcp/v20250618/method.go index 4a0ecaa4e0..24312d2da9 100644 --- a/internal/server/mcp/v20250618/method.go +++ b/internal/server/mcp/v20250618/method.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" "net/http" - "strings" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" @@ -117,7 +116,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } if clientAuth { if accessToken == "" { - return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized + err := util.NewClientServerError( + "missing access token in the 'Authorization' header", + http.StatusUnauthorized, + nil, + ) + return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } } @@ -165,7 +169,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re // Check if any of the specified auth services is verified isAuthorized := tool.Authorized(verifiedAuthServices) if !isAuthorized { - err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized) + err = util.NewClientServerError( + "unauthorized Tool call: Please make sure you specify correct auth headers", + http.StatusUnauthorized, + nil, + ) return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } logger.DebugContext(ctx, "tool invocation authorized") @@ -187,29 +195,44 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re // run tool invocation and generate response. results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) if err != nil { - errStr := err.Error() - // Missing authService tokens. - if errors.Is(err, util.ErrUnauthorized) { - return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err - } - // Upstream auth error - if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") { - if clientAuth { - // Error with client credentials should pass down to the client - return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err + var tbErr util.ToolboxError + + if errors.As(err, &tbErr) { + switch tbErr.Category() { + case util.CategoryAgent: + // MCP - Tool execution error + // Return SUCCESS but with IsError: true + text := TextContent{ + Type: "text", + Text: err.Error(), + } + return jsonrpc.JSONRPCResponse{ + Jsonrpc: jsonrpc.JSONRPC_VERSION, + Id: id, + Result: CallToolResult{Content: []TextContent{text}, IsError: true}, + }, nil + + case util.CategoryServer: + // MCP Spec - Protocol error + // Return JSON-RPC ERROR + var clientServerErr *util.ClientServerError + rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603) + + if errors.As(err, &clientServerErr) { + if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden { + if clientAuth { + rpcCode = jsonrpc.INVALID_REQUEST + } else { + rpcCode = jsonrpc.INTERNAL_ERROR + } + } + } + return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err } - // Auth error with ADC should raise internal 500 error + } else { + // Unknown error -> 500 return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err } - text := TextContent{ - Type: "text", - Text: err.Error(), - } - return jsonrpc.JSONRPCResponse{ - Jsonrpc: jsonrpc.JSONRPC_VERSION, - Id: id, - Result: CallToolResult{Content: []TextContent{text}, IsError: true}, - }, nil } content := make([]TextContent, 0) diff --git a/internal/server/mcp/v20251125/method.go b/internal/server/mcp/v20251125/method.go index 51d67d097c..408fd0303c 100644 --- a/internal/server/mcp/v20251125/method.go +++ b/internal/server/mcp/v20251125/method.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" "net/http" - "strings" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" @@ -117,7 +116,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } if clientAuth { if accessToken == "" { - return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized + err := util.NewClientServerError( + "missing access token in the 'Authorization' header", + http.StatusUnauthorized, + nil, + ) + return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } } @@ -165,7 +169,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re // Check if any of the specified auth services is verified isAuthorized := tool.Authorized(verifiedAuthServices) if !isAuthorized { - err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized) + err = util.NewClientServerError( + "unauthorized Tool call: Please make sure you specify correct auth headers", + http.StatusUnauthorized, + nil, + ) return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } logger.DebugContext(ctx, "tool invocation authorized") @@ -187,29 +195,44 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re // run tool invocation and generate response. results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) if err != nil { - errStr := err.Error() - // Missing authService tokens. - if errors.Is(err, util.ErrUnauthorized) { - return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err - } - // Upstream auth error - if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") { - if clientAuth { - // Error with client credentials should pass down to the client - return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err + var tbErr util.ToolboxError + + if errors.As(err, &tbErr) { + switch tbErr.Category() { + case util.CategoryAgent: + // MCP - Tool execution error + // Return SUCCESS but with IsError: true + text := TextContent{ + Type: "text", + Text: err.Error(), + } + return jsonrpc.JSONRPCResponse{ + Jsonrpc: jsonrpc.JSONRPC_VERSION, + Id: id, + Result: CallToolResult{Content: []TextContent{text}, IsError: true}, + }, nil + + case util.CategoryServer: + // MCP Spec - Protocol error + // Return JSON-RPC ERROR + var clientServerErr *util.ClientServerError + rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603) + + if errors.As(err, &clientServerErr) { + if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden { + if clientAuth { + rpcCode = jsonrpc.INVALID_REQUEST + } else { + rpcCode = jsonrpc.INTERNAL_ERROR + } + } + } + return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err } - // Auth error with ADC should raise internal 500 error + } else { + // Unknown error -> 500 return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err } - text := TextContent{ - Type: "text", - Text: err.Error(), - } - return jsonrpc.JSONRPCResponse{ - Jsonrpc: jsonrpc.JSONRPC_VERSION, - Id: id, - Result: CallToolResult{Content: []TextContent{text}, IsError: true}, - }, nil } content := make([]TextContent, 0) diff --git a/internal/server/mcp_test.go b/internal/server/mcp_test.go index 0d50af2b24..bbfce7ad41 100644 --- a/internal/server/mcp_test.go +++ b/internal/server/mcp_test.go @@ -231,7 +231,7 @@ func TestMcpEndpointWithoutInitialized(t *testing.T) { "id": "tools-call-tool4", "error": map[string]any{ "code": -32600.0, - "message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized", + "message": "unauthorized Tool call: Please make sure you specify correct auth headers", }, }, }, @@ -320,7 +320,7 @@ func TestMcpEndpointWithoutInitialized(t *testing.T) { Params: map[string]any{ "name": "prompt2", "arguments": map[string]any{ - "arg1": 42, // prompt2 expects a string, we send a number + "arg1": 42, }, }, }, @@ -834,7 +834,7 @@ func TestMcpEndpoint(t *testing.T) { "id": "tools-call-tool4", "error": map[string]any{ "code": -32600.0, - "message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized", + "message": "unauthorized Tool call: Please make sure you specify correct auth headers", }, }, }, diff --git a/internal/server/mocks.go b/internal/server/mocks.go index 60aa4f6212..56e458110b 100644 --- a/internal/server/mocks.go +++ b/internal/server/mocks.go @@ -21,6 +21,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -34,7 +35,7 @@ type MockTool struct { requiresClientAuthrorization bool } -func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) { +func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, util.ToolboxError) { mock := []any{t.Name} return mock, nil } diff --git a/internal/sources/alloydbadmin/alloydbadmin.go b/internal/sources/alloydbadmin/alloydbadmin.go index 6a9938d936..2761d96644 100644 --- a/internal/sources/alloydbadmin/alloydbadmin.go +++ b/internal/sources/alloydbadmin/alloydbadmin.go @@ -361,7 +361,11 @@ func (s *Source) GetOperations(ctx context.Context, project, location, operation } } - return string(opBytes), nil + var result any + if err := json.Unmarshal(opBytes, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal operation bytes: %w", err) + } + return result, nil } logger.DebugContext(ctx, fmt.Sprintf("Operation not complete, retrying in %v\n", delay)) } diff --git a/internal/tools/alloydb/alloydbcreatecluster/alloydbcreatecluster.go b/internal/tools/alloydb/alloydbcreatecluster/alloydbcreatecluster.go index db59b982a8..875d21aca5 100644 --- a/internal/tools/alloydb/alloydbcreatecluster/alloydbcreatecluster.go +++ b/internal/tools/alloydb/alloydbcreatecluster/alloydbcreatecluster.go @@ -17,11 +17,13 @@ package alloydbcreatecluster import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -122,44 +124,49 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok || project == "" { - return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string") + return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a non-empty string", nil) } location, ok := paramsMap["location"].(string) if !ok { - return nil, fmt.Errorf("invalid 'location' parameter; expected a string") + return nil, util.NewAgentError("invalid 'location' parameter; expected a string", nil) } clusterID, ok := paramsMap["cluster"].(string) if !ok || clusterID == "" { - return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string") + return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a non-empty string", nil) } password, ok := paramsMap["password"].(string) if !ok || password == "" { - return nil, fmt.Errorf("invalid or missing 'password' parameter; expected a non-empty string") + return nil, util.NewAgentError("invalid or missing 'password' parameter; expected a non-empty string", nil) } network, ok := paramsMap["network"].(string) if !ok { - return nil, fmt.Errorf("invalid 'network' parameter; expected a string") + return nil, util.NewAgentError("invalid 'network' parameter; expected a string", nil) } user, ok := paramsMap["user"].(string) if !ok { - return nil, fmt.Errorf("invalid 'user' parameter; expected a string") + return nil, util.NewAgentError("invalid 'user' parameter; expected a string", nil) + } + resp, err := source.CreateCluster(ctx, project, location, network, user, password, clusterID, string(accessToken)) + + if err != nil { + return nil, util.ProcessGcpError(err) } - return source.CreateCluster(ctx, project, location, network, user, password, clusterID, string(accessToken)) + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/alloydb/alloydbcreateinstance/alloydbcreateinstance.go b/internal/tools/alloydb/alloydbcreateinstance/alloydbcreateinstance.go index 8b0adc3646..ce98dbe44d 100644 --- a/internal/tools/alloydb/alloydbcreateinstance/alloydbcreateinstance.go +++ b/internal/tools/alloydb/alloydbcreateinstance/alloydbcreateinstance.go @@ -17,11 +17,13 @@ package alloydbcreateinstance import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -123,36 +125,36 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok || project == "" { - return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string") + return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a non-empty string", nil) } location, ok := paramsMap["location"].(string) if !ok || location == "" { - return nil, fmt.Errorf("invalid or missing 'location' parameter; expected a non-empty string") + return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a non-empty string", nil) } cluster, ok := paramsMap["cluster"].(string) if !ok || cluster == "" { - return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string") + return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a non-empty string", nil) } instanceID, ok := paramsMap["instance"].(string) if !ok || instanceID == "" { - return nil, fmt.Errorf("invalid or missing 'instance' parameter; expected a non-empty string") + return nil, util.NewAgentError("invalid or missing 'instance' parameter; expected a non-empty string", nil) } instanceType, ok := paramsMap["instanceType"].(string) if !ok || (instanceType != "READ_POOL" && instanceType != "PRIMARY") { - return nil, fmt.Errorf("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'") + return nil, util.NewAgentError("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'", nil) } displayName, _ := paramsMap["displayName"].(string) @@ -161,11 +163,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if instanceType == "READ_POOL" { nodeCount, ok = paramsMap["nodeCount"].(int) if !ok { - return nil, fmt.Errorf("invalid 'nodeCount' parameter; expected an integer for READ_POOL") + return nil, util.NewAgentError("invalid 'nodeCount' parameter; expected an integer for READ_POOL", nil) } } - return source.CreateInstance(ctx, project, location, cluster, instanceID, instanceType, displayName, nodeCount, string(accessToken)) + resp, err := source.CreateInstance(ctx, project, location, cluster, instanceID, instanceType, displayName, nodeCount, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/alloydb/alloydbcreateuser/alloydbcreateuser.go b/internal/tools/alloydb/alloydbcreateuser/alloydbcreateuser.go index f1c0cb7c64..4d59c1fcfc 100644 --- a/internal/tools/alloydb/alloydbcreateuser/alloydbcreateuser.go +++ b/internal/tools/alloydb/alloydbcreateuser/alloydbcreateuser.go @@ -17,11 +17,13 @@ package alloydbcreateuser import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -122,43 +124,43 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok || project == "" { - return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string") + return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a non-empty string", nil) } location, ok := paramsMap["location"].(string) if !ok || location == "" { - return nil, fmt.Errorf("invalid or missing'location' parameter; expected a non-empty string") + return nil, util.NewAgentError("invalid or missing'location' parameter; expected a non-empty string", nil) } cluster, ok := paramsMap["cluster"].(string) if !ok || cluster == "" { - return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string") + return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a non-empty string", nil) } userID, ok := paramsMap["user"].(string) if !ok || userID == "" { - return nil, fmt.Errorf("invalid or missing 'user' parameter; expected a non-empty string") + return nil, util.NewAgentError("invalid or missing 'user' parameter; expected a non-empty string", nil) } userType, ok := paramsMap["userType"].(string) if !ok || (userType != "ALLOYDB_BUILT_IN" && userType != "ALLOYDB_IAM_USER") { - return nil, fmt.Errorf("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'") + return nil, util.NewAgentError("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'", nil) } var password string if userType == "ALLOYDB_BUILT_IN" { password, ok = paramsMap["password"].(string) if !ok || password == "" { - return nil, fmt.Errorf("password is required when userType is ALLOYDB_BUILT_IN") + return nil, util.NewAgentError("password is required when userType is ALLOYDB_BUILT_IN", nil) } } @@ -170,7 +172,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } } - return source.CreateUser(ctx, userType, password, roles, string(accessToken), project, location, cluster, userID) + resp, err := source.CreateUser(ctx, userType, password, roles, string(accessToken), project, location, cluster, userID) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/alloydb/alloydbgetcluster/alloydbgetcluster.go b/internal/tools/alloydb/alloydbgetcluster/alloydbgetcluster.go index d0dc9d7269..a0875fbe3e 100644 --- a/internal/tools/alloydb/alloydbgetcluster/alloydbgetcluster.go +++ b/internal/tools/alloydb/alloydbgetcluster/alloydbgetcluster.go @@ -17,11 +17,13 @@ package alloydbgetcluster import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -120,28 +122,32 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) - if !ok { - return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string") + if !ok || project == "" { + return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil) } location, ok := paramsMap["location"].(string) - if !ok { - return nil, fmt.Errorf("invalid 'location' parameter; expected a string") + if !ok || location == "" { + return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a string", nil) } cluster, ok := paramsMap["cluster"].(string) - if !ok { - return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string") + if !ok || cluster == "" { + return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a string", nil) } - return source.GetCluster(ctx, project, location, cluster, string(accessToken)) + resp, err := source.GetCluster(ctx, project, location, cluster, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/alloydb/alloydbgetinstance/alloydbgetinstance.go b/internal/tools/alloydb/alloydbgetinstance/alloydbgetinstance.go index 569d7dda70..e0ceb1ab6c 100644 --- a/internal/tools/alloydb/alloydbgetinstance/alloydbgetinstance.go +++ b/internal/tools/alloydb/alloydbgetinstance/alloydbgetinstance.go @@ -17,11 +17,13 @@ package alloydbgetinstance import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -120,32 +122,36 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) - if !ok { - return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string") + if !ok || project == "" { + return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil) } location, ok := paramsMap["location"].(string) - if !ok { - return nil, fmt.Errorf("invalid 'location' parameter; expected a string") + if !ok || location == "" { + return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a string", nil) } cluster, ok := paramsMap["cluster"].(string) - if !ok { - return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string") + if !ok || cluster == "" { + return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a string", nil) } instance, ok := paramsMap["instance"].(string) - if !ok { - return nil, fmt.Errorf("invalid 'instance' parameter; expected a string") + if !ok || instance == "" { + return nil, util.NewAgentError("invalid or missing 'instance' parameter; expected a string", nil) } - return source.GetInstance(ctx, project, location, cluster, instance, string(accessToken)) + resp, err := source.GetInstance(ctx, project, location, cluster, instance, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/alloydb/alloydbgetuser/alloydbgetuser.go b/internal/tools/alloydb/alloydbgetuser/alloydbgetuser.go index 9b9d532a6c..ae7986a846 100644 --- a/internal/tools/alloydb/alloydbgetuser/alloydbgetuser.go +++ b/internal/tools/alloydb/alloydbgetuser/alloydbgetuser.go @@ -17,11 +17,13 @@ package alloydbgetuser import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -120,32 +122,36 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) - if !ok { - return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string") + if !ok || project == "" { + return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil) } location, ok := paramsMap["location"].(string) - if !ok { - return nil, fmt.Errorf("invalid 'location' parameter; expected a string") + if !ok || location == "" { + return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a string", nil) } cluster, ok := paramsMap["cluster"].(string) - if !ok { - return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string") + if !ok || cluster == "" { + return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a string", nil) } user, ok := paramsMap["user"].(string) - if !ok { - return nil, fmt.Errorf("invalid 'user' parameter; expected a string") + if !ok || user == "" { + return nil, util.NewAgentError("invalid or missing 'user' parameter; expected a string", nil) } - return source.GetUsers(ctx, project, location, cluster, user, string(accessToken)) + resp, err := source.GetUsers(ctx, project, location, cluster, user, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/alloydb/alloydblistclusters/alloydblistclusters.go b/internal/tools/alloydb/alloydblistclusters/alloydblistclusters.go index 0477d05d55..ee624f039f 100644 --- a/internal/tools/alloydb/alloydblistclusters/alloydblistclusters.go +++ b/internal/tools/alloydb/alloydblistclusters/alloydblistclusters.go @@ -17,11 +17,13 @@ package alloydblistclusters import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -118,24 +120,28 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) - if !ok { - return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string") + if !ok || project == "" { + return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil) } location, ok := paramsMap["location"].(string) if !ok { - return nil, fmt.Errorf("invalid 'location' parameter; expected a string") + return nil, util.NewAgentError("invalid 'location' parameter; expected a string", nil) } - return source.ListCluster(ctx, project, location, string(accessToken)) + resp, err := source.ListCluster(ctx, project, location, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/alloydb/alloydblistinstances/alloydblistinstances.go b/internal/tools/alloydb/alloydblistinstances/alloydblistinstances.go index 749bdd5ea4..86f0b3b21a 100644 --- a/internal/tools/alloydb/alloydblistinstances/alloydblistinstances.go +++ b/internal/tools/alloydb/alloydblistinstances/alloydblistinstances.go @@ -17,11 +17,13 @@ package alloydblistinstances import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -119,28 +121,32 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) - if !ok { - return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string") + if !ok || project == "" { + return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil) } location, ok := paramsMap["location"].(string) if !ok { - return nil, fmt.Errorf("invalid 'location' parameter; expected a string") + return nil, util.NewAgentError("invalid 'location' parameter; expected a string", nil) } cluster, ok := paramsMap["cluster"].(string) if !ok { - return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string") + return nil, util.NewAgentError("invalid 'cluster' parameter; expected a string", nil) } - return source.ListInstance(ctx, project, location, cluster, string(accessToken)) + resp, err := source.ListInstance(ctx, project, location, cluster, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/alloydb/alloydblistusers/alloydblistusers.go b/internal/tools/alloydb/alloydblistusers/alloydblistusers.go index cbcc1a545c..6987b6e82e 100644 --- a/internal/tools/alloydb/alloydblistusers/alloydblistusers.go +++ b/internal/tools/alloydb/alloydblistusers/alloydblistusers.go @@ -17,11 +17,13 @@ package alloydblistusers import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -119,28 +121,32 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) - if !ok { - return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string") + if !ok || project == "" { + return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil) } location, ok := paramsMap["location"].(string) - if !ok { - return nil, fmt.Errorf("invalid 'location' parameter; expected a string") + if !ok || location == "" { + return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a string", nil) } cluster, ok := paramsMap["cluster"].(string) - if !ok { - return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string") + if !ok || cluster == "" { + return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a string", nil) } - return source.ListUsers(ctx, project, location, cluster, string(accessToken)) + resp, err := source.ListUsers(ctx, project, location, cluster, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go b/internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go index 8f10fed7e3..05ca8b7780 100644 --- a/internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go +++ b/internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go @@ -24,6 +24,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -213,25 +214,25 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok { - return nil, fmt.Errorf("missing 'project' parameter") + return nil, util.NewAgentError("missing 'project' parameter", nil) } location, ok := paramsMap["location"].(string) if !ok { - return nil, fmt.Errorf("missing 'location' parameter") + return nil, util.NewAgentError("missing 'location' parameter", nil) } operation, ok := paramsMap["operation"].(string) if !ok { - return nil, fmt.Errorf("missing 'operation' parameter") + return nil, util.NewAgentError("missing 'operation' parameter", nil) } ctx, cancel := context.WithTimeout(ctx, 30*time.Minute) @@ -246,14 +247,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para for retries < maxRetries { select { case <-ctx.Done(): - return nil, fmt.Errorf("timed out waiting for operation: %w", ctx.Err()) + return nil, util.NewAgentError("timed out waiting for operation", ctx.Err()) default: } op, err := source.GetOperations(ctx, project, location, operation, alloyDBConnectionMessageTemplate, delay, string(accessToken)) if err != nil { - return nil, err - } else if op != nil { + return nil, util.ProcessGeneralError(err) + } + if op != nil { return op, nil } @@ -264,7 +266,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } retries++ } - return nil, fmt.Errorf("exceeded max retries waiting for operation") + return nil, util.NewAgentError("exceeded max retries waiting for operation", nil) } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/alloydbainl/alloydbainl.go b/internal/tools/alloydbainl/alloydbainl.go index 4a0b8b9ba8..98cf20870b 100644 --- a/internal/tools/alloydbainl/alloydbainl.go +++ b/internal/tools/alloydbainl/alloydbainl.go @@ -17,12 +17,14 @@ package alloydbainl import ( "context" "fmt" + "net/http" "strings" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -127,10 +129,10 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } sliceParams := params.AsSlice() @@ -143,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para resp, err := source.RunSQL(ctx, t.Statement, allParamValues) if err != nil { - return nil, fmt.Errorf("%w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues) + return nil, util.NewClientServerError(fmt.Sprintf("error running SQL query: %v. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues), http.StatusBadRequest, err) } return resp, nil } diff --git a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go index e9758ba7a9..f8d453039b 100644 --- a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go +++ b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go @@ -17,6 +17,7 @@ package bigqueryanalyzecontribution import ( "context" "fmt" + "net/http" "strings" bigqueryapi "cloud.google.com/go/bigquery" @@ -27,6 +28,7 @@ import ( bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" ) @@ -154,21 +156,21 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke runs the contribution analysis. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() inputData, ok := paramsMap["input_data"].(string) if !ok { - return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast input_data parameter %s", paramsMap["input_data"]), nil) } bqClient, restService, err := source.RetrieveClientAndService(accessToken) if err != nil { - return nil, err + return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) } modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) @@ -186,7 +188,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } options = append(options, fmt.Sprintf("DIMENSION_ID_COLS = [%s]", strings.Join(strCols, ", "))) } else { - return nil, fmt.Errorf("unable to cast dimension_id_cols parameter %s", paramsMap["dimension_id_cols"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast dimension_id_cols parameter %s", paramsMap["dimension_id_cols"]), nil) } } if val, ok := paramsMap["top_k_insights_by_apriori_support"]; ok { @@ -195,7 +197,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if val, ok := paramsMap["pruning_method"].(string); ok { upperVal := strings.ToUpper(val) if upperVal != "NO_PRUNING" && upperVal != "PRUNE_REDUNDANT_INSIGHTS" { - return nil, fmt.Errorf("invalid pruning_method: %s", val) + return nil, util.NewAgentError(fmt.Sprintf("invalid pruning_method: %s", val), nil) } options = append(options, fmt.Sprintf("PRUNING_METHOD = '%s'", upperVal)) } @@ -207,7 +209,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var connProps []*bigqueryapi.ConnectionProperty session, err := source.BigQuerySession()(ctx) if err != nil { - return nil, fmt.Errorf("failed to get BigQuery session: %w", err) + return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err) } if session != nil { connProps = []*bigqueryapi.ConnectionProperty{ @@ -216,22 +218,22 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps) if err != nil { - return nil, fmt.Errorf("query validation failed: %w", err) + return nil, util.ProcessGcpError(err) } statementType := dryRunJob.Statistics.Query.StatementType if statementType != "SELECT" { - return nil, fmt.Errorf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType) + return nil, util.NewAgentError(fmt.Sprintf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType), nil) } queryStats := dryRunJob.Statistics.Query if queryStats != nil { for _, tableRef := range queryStats.ReferencedTables { if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { - return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId) + return nil, util.NewAgentError(fmt.Sprintf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId), nil) } } } else { - return nil, fmt.Errorf("could not analyze query in input_data to validate against allowed datasets") + return nil, util.NewAgentError("could not analyze query in input_data to validate against allowed datasets", nil) } } inputDataSource = fmt.Sprintf("(%s)", inputData) @@ -245,10 +247,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para case 2: // dataset.table projectID, datasetID = source.BigQueryClient().Project(), parts[0] default: - return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData) + return nil, util.NewAgentError(fmt.Sprintf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData), nil) } if !source.IsDatasetAllowed(projectID, datasetID) { - return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData) + return nil, util.NewAgentError(fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData), nil) } } inputDataSource = fmt.Sprintf("SELECT * FROM `%s`", inputData) @@ -268,7 +270,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Otherwise, a new session will be created by the first query. session, err := source.BigQuerySession()(ctx) if err != nil { - return nil, fmt.Errorf("failed to get BigQuery session: %w", err) + return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err) } if session != nil { @@ -281,15 +283,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } createModelJob, err := createModelQuery.Run(ctx) if err != nil { - return nil, fmt.Errorf("failed to start create model job: %w", err) + return nil, util.ProcessGcpError(err) } status, err := createModelJob.Wait(ctx) if err != nil { - return nil, fmt.Errorf("failed to wait for create model job: %w", err) + return nil, util.ProcessGcpError(err) } if err := status.Err(); err != nil { - return nil, fmt.Errorf("create model job failed: %w", err) + return nil, util.ProcessGcpError(err) } // Determine the session ID to use for subsequent queries. @@ -300,12 +302,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } else if status.Statistics != nil && status.Statistics.SessionInfo != nil { sessionID = status.Statistics.SessionInfo.SessionID } else { - return nil, fmt.Errorf("failed to get or create a BigQuery session ID") + return nil, util.NewClientServerError("failed to get or create a BigQuery session ID", http.StatusInternalServerError, nil) } getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID) connProps := []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}} - return source.RunSQL(ctx, bqClient, getInsightsSQL, "SELECT", nil, connProps) + + resp, err := source.RunSQL(ctx, bqClient, getInsightsSQL, "SELECT", nil, connProps) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go index a3b908b29d..196a08b51d 100644 --- a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go +++ b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go @@ -172,10 +172,10 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } var tokenStr string @@ -184,26 +184,26 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if source.UseClientAuthorization() { // Use client-side access token if accessToken == "" { - return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized) + return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil) } tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } else { // Get a token source for the Gemini Data Analytics API. tokenSource, err := source.BigQueryTokenSourceWithScope(ctx, nil) if err != nil { - return nil, fmt.Errorf("failed to get token source: %w", err) + return nil, util.NewClientServerError("failed to get token source", http.StatusInternalServerError, err) } // Use cloud-platform token source for Gemini Data Analytics API if tokenSource == nil { - return nil, fmt.Errorf("cloud-platform token source is missing") + return nil, util.NewClientServerError("cloud-platform token source is missing", http.StatusInternalServerError, nil) } token, err := tokenSource.Token() if err != nil { - return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err) + return nil, util.NewClientServerError("failed to get token from cloud-platform token source", http.StatusInternalServerError, err) } tokenStr = token.AccessToken } @@ -218,14 +218,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var tableRefs []BQTableReference if tableRefsJSON != "" { if err := json.Unmarshal([]byte(tableRefsJSON), &tableRefs); err != nil { - return nil, fmt.Errorf("failed to parse 'table_references' JSON string: %w", err) + return nil, util.NewAgentError("failed to parse 'table_references' JSON string", err) } } if len(source.BigQueryAllowedDatasets()) > 0 { for _, tableRef := range tableRefs { if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) { - return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID) + return nil, util.NewAgentError(fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID), nil) } } } @@ -258,7 +258,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Call the streaming API response, err := getStream(caURL, payload, headers, source.GetMaxQueryResultRows()) if err != nil { - return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err) + // getStream wraps network errors or non-200 responses + return nil, util.NewClientServerError("failed to get response from conversational analytics API", http.StatusInternalServerError, err) } return response, nil diff --git a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go index e14cfea511..157740c1bb 100644 --- a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go +++ b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "strings" bigqueryapi "cloud.google.com/go/bigquery" @@ -152,25 +153,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast sql parameter %s", paramsMap["sql"]), nil) } dryRun, ok := paramsMap["dry_run"].(bool) if !ok { - return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast dry_run parameter %s", paramsMap["dry_run"]), nil) } bqClient, restService, err := source.RetrieveClientAndService(accessToken) if err != nil { - return nil, err + return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) } var connProps []*bigqueryapi.ConnectionProperty @@ -178,7 +179,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if source.BigQueryWriteMode() == bigqueryds.WriteModeProtected { session, err = source.BigQuerySession()(ctx) if err != nil { - return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err) + return nil, util.NewClientServerError("failed to get BigQuery session for protected mode", http.StatusInternalServerError, err) } connProps = []*bigqueryapi.ConnectionProperty{ {Key: "session_id", Value: session.ID}, @@ -187,7 +188,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps) if err != nil { - return nil, fmt.Errorf("query validation failed: %w", err) + return nil, util.NewClientServerError("query validation failed", http.StatusInternalServerError, err) } statementType := dryRunJob.Statistics.Query.StatementType @@ -195,13 +196,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para switch source.BigQueryWriteMode() { case bigqueryds.WriteModeBlocked: if statementType != "SELECT" { - return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed") + return nil, util.NewAgentError("write mode is 'blocked', only SELECT statements are allowed", nil) } case bigqueryds.WriteModeProtected: if dryRunJob.Configuration != nil && dryRunJob.Configuration.Query != nil { if dest := dryRunJob.Configuration.Query.DestinationTable; dest != nil && dest.DatasetId != session.DatasetID { - return nil, fmt.Errorf("protected write mode only supports SELECT statements, or write operations in the anonymous "+ - "dataset of a BigQuery session, but destination was %q", dest.DatasetId) + return nil, util.NewAgentError(fmt.Sprintf("protected write mode only supports SELECT statements, or write operations in the anonymous "+ + "dataset of a BigQuery session, but destination was %q", dest.DatasetId), nil) } } } @@ -209,11 +210,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if len(source.BigQueryAllowedDatasets()) > 0 { switch statementType { case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA": - return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType) + return nil, util.NewAgentError(fmt.Sprintf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType), nil) case "CREATE_FUNCTION", "CREATE_TABLE_FUNCTION", "CREATE_PROCEDURE": - return nil, fmt.Errorf("creating stored routines ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType) + return nil, util.NewAgentError(fmt.Sprintf("creating stored routines ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType), nil) case "CALL": - return nil, fmt.Errorf("calling stored procedures ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType) + return nil, util.NewAgentError(fmt.Sprintf("calling stored procedures ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType), nil) } // Use a map to avoid duplicate table names. @@ -244,7 +245,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project()) if parseErr != nil { // If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail. - return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr) + return nil, util.NewAgentError("could not parse tables from query to validate against allowed datasets", parseErr) } tableNames = parsedTables } @@ -254,7 +255,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if len(parts) == 3 { projectID, datasetID := parts[0], parts[1] if !source.IsDatasetAllowed(projectID, datasetID) { - return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID) + return nil, util.NewAgentError(fmt.Sprintf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID), nil) } } } @@ -264,7 +265,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if dryRunJob != nil { jobJSON, err := json.MarshalIndent(dryRunJob, "", " ") if err != nil { - return nil, fmt.Errorf("failed to marshal dry run job to JSON: %w", err) + return nil, util.NewClientServerError("failed to marshal dry run job to JSON", http.StatusInternalServerError, err) } return string(jobJSON), nil } @@ -275,10 +276,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, bqClient, sql, statementType, nil, connProps) + resp, err := source.RunSQL(ctx, bqClient, sql, statementType, nil, connProps) + if err != nil { + return nil, util.NewClientServerError("error running sql", http.StatusInternalServerError, err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go index 72f244bd96..5f4c5ce1f6 100644 --- a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go +++ b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go @@ -17,6 +17,7 @@ package bigqueryforecast import ( "context" "fmt" + "net/http" "strings" bigqueryapi "cloud.google.com/go/bigquery" @@ -133,34 +134,34 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() historyData, ok := paramsMap["history_data"].(string) if !ok { - return nil, fmt.Errorf("unable to cast history_data parameter %v", paramsMap["history_data"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast history_data parameter %v", paramsMap["history_data"]), nil) } timestampCol, ok := paramsMap["timestamp_col"].(string) if !ok { - return nil, fmt.Errorf("unable to cast timestamp_col parameter %v", paramsMap["timestamp_col"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast timestamp_col parameter %v", paramsMap["timestamp_col"]), nil) } dataCol, ok := paramsMap["data_col"].(string) if !ok { - return nil, fmt.Errorf("unable to cast data_col parameter %v", paramsMap["data_col"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast data_col parameter %v", paramsMap["data_col"]), nil) } idColsRaw, ok := paramsMap["id_cols"].([]any) if !ok { - return nil, fmt.Errorf("unable to cast id_cols parameter %v", paramsMap["id_cols"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast id_cols parameter %v", paramsMap["id_cols"]), nil) } var idCols []string for _, v := range idColsRaw { s, ok := v.(string) if !ok { - return nil, fmt.Errorf("id_cols contains non-string value: %v", v) + return nil, util.NewAgentError(fmt.Sprintf("id_cols contains non-string value: %v", v), nil) } idCols = append(idCols, s) } @@ -169,13 +170,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if h, ok := paramsMap["horizon"].(float64); ok { horizon = int(h) } else { - return nil, fmt.Errorf("unable to cast horizon parameter %v", paramsMap["horizon"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast horizon parameter %v", paramsMap["horizon"]), nil) } } bqClient, restService, err := source.RetrieveClientAndService(accessToken) if err != nil { - return nil, err + return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) } var historyDataSource string @@ -185,7 +186,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var connProps []*bigqueryapi.ConnectionProperty session, err := source.BigQuerySession()(ctx) if err != nil { - return nil, fmt.Errorf("failed to get BigQuery session: %w", err) + return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err) } if session != nil { connProps = []*bigqueryapi.ConnectionProperty{ @@ -194,22 +195,22 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps) if err != nil { - return nil, fmt.Errorf("query validation failed: %w", err) + return nil, util.ProcessGcpError(err) } statementType := dryRunJob.Statistics.Query.StatementType if statementType != "SELECT" { - return nil, fmt.Errorf("the 'history_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType) + return nil, util.NewAgentError(fmt.Sprintf("the 'history_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType), nil) } queryStats := dryRunJob.Statistics.Query if queryStats != nil { for _, tableRef := range queryStats.ReferencedTables { if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { - return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId) + return nil, util.NewAgentError(fmt.Sprintf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId), nil) } } } else { - return nil, fmt.Errorf("could not analyze query in history_data to validate against allowed datasets") + return nil, util.NewAgentError("could not analyze query in history_data to validate against allowed datasets", nil) } } historyDataSource = fmt.Sprintf("(%s)", historyData) @@ -226,11 +227,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para projectID = source.BigQueryClient().Project() datasetID = parts[0] default: - return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData) + return nil, util.NewAgentError(fmt.Sprintf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData), nil) } if !source.IsDatasetAllowed(projectID, datasetID) { - return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData) + return nil, util.NewAgentError(fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData), nil) } } historyDataSource = fmt.Sprintf("TABLE `%s`", historyData) @@ -243,15 +244,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sql := fmt.Sprintf(`SELECT * FROM AI.FORECAST( - %s, - data_col => '%s', - timestamp_col => '%s', - horizon => %d%s)`, + %s, + data_col => '%s', + timestamp_col => '%s', + horizon => %d%s)`, historyDataSource, dataCol, timestampCol, horizon, idColsArg) session, err := source.BigQuerySession()(ctx) if err != nil { - return nil, fmt.Errorf("failed to get BigQuery session: %w", err) + return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err) } var connProps []*bigqueryapi.ConnectionProperty if session != nil { @@ -264,11 +265,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, bqClient, sql, "SELECT", nil, connProps) + resp, err := source.RunSQL(ctx, bqClient, sql, "SELECT", nil, connProps) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go b/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go index b3844d20cd..36d97ddb0e 100644 --- a/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go +++ b/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go @@ -17,6 +17,7 @@ package bigquerygetdatasetinfo import ( "context" "fmt" + "net/http" bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" @@ -24,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" ) @@ -120,38 +122,38 @@ type Tool struct { func (t Tool) ToConfig() tools.ToolConfig { return t.Config } - -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() projectId, ok := mapParams[projectKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey) + // Updated: Use fmt.Sprintf for formatting, pass nil as cause + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", projectKey), nil) } datasetId, ok := mapParams[datasetKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", datasetKey), nil) } bqClient, _, err := source.RetrieveClientAndService(accessToken) if err != nil { - return nil, err + return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) } if !source.IsDatasetAllowed(projectId, datasetId) { - return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) + return nil, util.NewAgentError(fmt.Sprintf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId), nil) } dsHandle := bqClient.DatasetInProject(projectId, datasetId) metadata, err := dsHandle.Metadata(ctx) if err != nil { - return nil, fmt.Errorf("failed to get metadata for dataset %s (in project %s): %w", datasetId, projectId, err) + return nil, util.ProcessGcpError(err) } return metadata, nil diff --git a/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go b/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go index b7131df89f..fcf1703b66 100644 --- a/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go +++ b/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go @@ -17,6 +17,7 @@ package bigquerygettableinfo import ( "context" "fmt" + "net/http" bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" @@ -24,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" ) @@ -125,35 +127,35 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() projectId, ok := mapParams[projectKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", projectKey), nil) } datasetId, ok := mapParams[datasetKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", datasetKey), nil) } tableId, ok := mapParams[tableKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", tableKey), nil) } if !source.IsDatasetAllowed(projectId, datasetId) { - return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) + return nil, util.NewAgentError(fmt.Sprintf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId), nil) } bqClient, _, err := source.RetrieveClientAndService(accessToken) if err != nil { - return nil, err + return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) } dsHandle := bqClient.DatasetInProject(projectId, datasetId) @@ -161,7 +163,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para metadata, err := tableHandle.Metadata(ctx) if err != nil { - return nil, fmt.Errorf("failed to get metadata for table %s.%s.%s: %w", projectId, datasetId, tableId, err) + return nil, util.ProcessGcpError(err) } return metadata, nil diff --git a/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go b/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go index 186ad7be54..12d819c420 100644 --- a/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go +++ b/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go @@ -17,12 +17,14 @@ package bigquerylistdatasetids import ( "context" "fmt" + "net/http" bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" "google.golang.org/api/iterator" @@ -120,10 +122,10 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } if len(source.BigQueryAllowedDatasets()) > 0 { @@ -132,12 +134,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para mapParams := params.AsMap() projectId, ok := mapParams[projectKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", projectKey), nil) } bqClient, _, err := source.RetrieveClientAndService(accessToken) if err != nil { - return nil, err + return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) } datasetIterator := bqClient.Datasets(ctx) datasetIterator.ProjectID = projectId @@ -149,7 +151,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para break } if err != nil { - return nil, fmt.Errorf("unable to iterate through datasets: %w", err) + return nil, util.ProcessGcpError(err) } // Remove leading and trailing quotes diff --git a/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go b/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go index 4390a89961..f566759cea 100644 --- a/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go +++ b/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go @@ -17,6 +17,7 @@ package bigquerylisttableids import ( "context" "fmt" + "net/http" bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" @@ -24,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" "google.golang.org/api/iterator" @@ -123,31 +125,30 @@ type Tool struct { func (t Tool) ToConfig() tools.ToolConfig { return t.Config } - -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() projectId, ok := mapParams[projectKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", projectKey), nil) } datasetId, ok := mapParams[datasetKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", datasetKey), nil) } if !source.IsDatasetAllowed(projectId, datasetId) { - return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) + return nil, util.NewAgentError(fmt.Sprintf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId), nil) } bqClient, _, err := source.RetrieveClientAndService(accessToken) if err != nil { - return nil, err + return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) } dsHandle := bqClient.DatasetInProject(projectId, datasetId) @@ -160,7 +161,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para break } if err != nil { - return nil, fmt.Errorf("failed to iterate through tables in dataset %s.%s: %w", projectId, datasetId, err) + return nil, util.ProcessGcpError(err) } // Remove leading and trailing quotes diff --git a/internal/tools/bigquery/bigquerysearchcatalog/bigquerysearchcatalog.go b/internal/tools/bigquery/bigquerysearchcatalog/bigquerysearchcatalog.go index 323dbbebb1..3cb5393178 100644 --- a/internal/tools/bigquery/bigquerysearchcatalog/bigquerysearchcatalog.go +++ b/internal/tools/bigquery/bigquerysearchcatalog/bigquerysearchcatalog.go @@ -17,6 +17,7 @@ package bigquerysearchcatalog import ( "context" "fmt" + "net/http" "strings" dataplexapi "cloud.google.com/go/dataplex/apiv1" @@ -26,6 +27,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/iterator" ) @@ -186,28 +188,31 @@ func ExtractType(resourceString string) string { return typeMap[resourceString[lastIndex+1:]] } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() pageSize := int32(paramsMap["pageSize"].(int)) prompt, _ := paramsMap["prompt"].(string) + projectIdSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["projectIds"].([]any), "string") if err != nil { - return nil, fmt.Errorf("can't convert projectIds to array of strings: %s", err) + return nil, util.NewAgentError(fmt.Sprintf("can't convert projectIds to array of strings: %s", err), err) } projectIds := projectIdSlice.([]string) + datasetIdSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["datasetIds"].([]any), "string") if err != nil { - return nil, fmt.Errorf("can't convert datasetIds to array of strings: %s", err) + return nil, util.NewAgentError(fmt.Sprintf("can't convert datasetIds to array of strings: %s", err), err) } datasetIds := datasetIdSlice.([]string) + typesSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["types"].([]any), "string") if err != nil { - return nil, fmt.Errorf("can't convert types to array of strings: %s", err) + return nil, util.NewAgentError(fmt.Sprintf("can't convert types to array of strings: %s", err), err) } types := typesSlice.([]string) @@ -223,17 +228,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } catalogClient, err = dataplexClientCreator(tokenStr) if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) + return nil, util.NewClientServerError("error creating client from OAuth access token", http.StatusInternalServerError, err) } } it := catalogClient.SearchEntries(ctx, req) if it == nil { - return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.BigQueryProject()) + return nil, util.NewClientServerError(fmt.Sprintf("failed to create search entries iterator for project %q", source.BigQueryProject()), http.StatusInternalServerError, nil) } var results []Response @@ -243,7 +248,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para break } if err != nil { - break + return nil, util.ProcessGcpError(err) } entrySource := entry.DataplexEntry.GetEntrySource() resp := Response{ diff --git a/internal/tools/bigquery/bigquerysql/bigquerysql.go b/internal/tools/bigquery/bigquerysql/bigquerysql.go index 78685deaa3..062511eacb 100644 --- a/internal/tools/bigquery/bigquerysql/bigquerysql.go +++ b/internal/tools/bigquery/bigquerysql/bigquerysql.go @@ -17,6 +17,7 @@ package bigquerysql import ( "context" "fmt" + "net/http" "reflect" "strings" @@ -27,6 +28,7 @@ import ( bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" ) @@ -103,11 +105,10 @@ type Tool struct { func (t Tool) ToConfig() tools.ToolConfig { return t.Config } - -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters)) @@ -116,7 +117,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } for _, p := range t.Parameters { @@ -127,13 +128,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if arrayParam, ok := p.(*parameters.ArrayParameter); ok { arrayParamValue, ok := value.([]any) if !ok { - return nil, fmt.Errorf("unable to convert parameter `%s` to []any", name) + return nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` to []any", name), nil) } itemType := arrayParam.GetItems().GetType() var err error value, err = parameters.ConvertAnySliceToTyped(arrayParamValue, itemType) if err != nil { - return nil, fmt.Errorf("unable to convert parameter `%s` from []any to typed slice: %w", name, err) + return nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` from []any to typed slice", name), err) } } @@ -161,7 +162,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para lowLevelParam.ParameterType.Type = "ARRAY" itemType, err := bqutil.BQTypeStringFromToolType(arrayParam.GetItems().GetType()) if err != nil { - return nil, err + return nil, util.NewAgentError("unable to get BigQuery type from tool parameter type", err) } lowLevelParam.ParameterType.ArrayType = &bigqueryrestapi.QueryParameterType{Type: itemType} @@ -178,7 +179,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Handle scalar types based on their defined type. bqType, err := bqutil.BQTypeStringFromToolType(p.GetType()) if err != nil { - return nil, err + return nil, util.NewAgentError("unable to get BigQuery type from tool parameter type", err) } lowLevelParam.ParameterType.Type = bqType lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value) @@ -190,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if source.BigQuerySession() != nil { session, err := source.BigQuerySession()(ctx) if err != nil { - return nil, fmt.Errorf("failed to get BigQuery session: %w", err) + return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err) } if session != nil { // Add session ID to the connection properties for subsequent calls. @@ -200,17 +201,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para bqClient, restService, err := source.RetrieveClientAndService(accessToken) if err != nil { - return nil, err + return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err) } dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps) if err != nil { - return nil, fmt.Errorf("query validation failed: %w", err) + return nil, util.ProcessGcpError(err) } statementType := dryRunJob.Statistics.Query.StatementType - - return source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps) + resp, err := source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/bigtable/bigtable.go b/internal/tools/bigtable/bigtable.go index 4c47ca945e..48f659e95e 100644 --- a/internal/tools/bigtable/bigtable.go +++ b/internal/tools/bigtable/bigtable.go @@ -17,12 +17,14 @@ package bigtable import ( "context" "fmt" + "net/http" "cloud.google.com/go/bigtable" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -96,24 +98,28 @@ type Tool struct { func (t Tool) ToConfig() tools.ToolConfig { return t.Config } - -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } - return source.RunSQL(ctx, newStatement, t.Parameters, newParams) + + resp, err := source.RunSQL(ctx, newStatement, t.Parameters, newParams) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cassandra/cassandracql/cassandracql.go b/internal/tools/cassandra/cassandracql/cassandracql.go index 6dcd2a013a..2cdcd92e57 100644 --- a/internal/tools/cassandra/cassandracql/cassandracql.go +++ b/internal/tools/cassandra/cassandracql/cassandracql.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -17,12 +17,14 @@ package cassandracql import ( "context" "fmt" + "net/http" gocql "github.com/apache/cassandra-gocql-driver/v2" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -107,23 +109,27 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { } // Invoke implements tools.Tool. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } - return source.RunSQL(ctx, newStatement, newParams) + resp, err := source.RunSQL(ctx, newStatement, newParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } // Manifest implements tools.Tool. diff --git a/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go b/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go index eefa02c6fa..8b69d71b60 100644 --- a/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go +++ b/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go @@ -17,11 +17,13 @@ package clickhouse import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -87,18 +89,22 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast sql parameter %s", paramsMap["sql"]), nil) } - return source.RunSQL(ctx, sql, nil) + resp, err := source.RunSQL(ctx, sql, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go b/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go index 317c462935..900649f4a8 100644 --- a/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go +++ b/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go @@ -17,11 +17,13 @@ package clickhouse import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -86,10 +88,10 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } // Query to list all databases @@ -97,7 +99,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para out, err := source.RunSQL(ctx, query, nil) if err != nil { - return nil, err + return nil, util.ProcessGeneralError(err) } return out, nil diff --git a/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go b/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go index 492bc281ad..10fb432d55 100644 --- a/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go +++ b/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go @@ -17,11 +17,13 @@ package clickhouse import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -90,34 +92,37 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() database, ok := mapParams[databaseKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", databaseKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", databaseKey), nil) } + // Query to list all tables in the specified database + // Note: formatting identifier directly is risky if input is untrusted, but standard for this tool structure. query := fmt.Sprintf("SHOW TABLES FROM %s", database) out, err := source.RunSQL(ctx, query, nil) if err != nil { - return nil, err + return nil, util.ProcessGeneralError(err) } res, ok := out.([]any) if !ok { - return nil, fmt.Errorf("unable to convert result to list") + return nil, util.NewClientServerError("unable to convert result to list", http.StatusInternalServerError, nil) } + var tables []map[string]any for _, item := range res { tableMap, ok := item.(map[string]any) if !ok { - return nil, fmt.Errorf("unexpected type in result: got %T, want map[string]any", item) + return nil, util.NewClientServerError(fmt.Sprintf("unexpected type in result: got %T, want map[string]any", item), http.StatusInternalServerError, nil) } tableMap["database"] = database tables = append(tables, tableMap) diff --git a/internal/tools/clickhouse/clickhousesql/clickhousesql.go b/internal/tools/clickhouse/clickhousesql/clickhousesql.go index 10645d309a..aafd98b2e0 100644 --- a/internal/tools/clickhouse/clickhousesql/clickhousesql.go +++ b/internal/tools/clickhouse/clickhousesql/clickhousesql.go @@ -17,11 +17,13 @@ package clickhouse import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -88,24 +90,28 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params: %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params: %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } - return source.RunSQL(ctx, newStatement, newParams) + resp, err := source.RunSQL(ctx, newStatement, newParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudgda/cloudgda.go b/internal/tools/cloudgda/cloudgda.go index a650c8e4a1..14862909b4 100644 --- a/internal/tools/cloudgda/cloudgda.go +++ b/internal/tools/cloudgda/cloudgda.go @@ -18,11 +18,13 @@ import ( "context" "encoding/json" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -119,17 +121,16 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -// Invoke executes the tool logic -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() query, ok := paramsMap["query"].(string) if !ok { - return nil, fmt.Errorf("query parameter not found or not a string") + return nil, util.NewAgentError("query parameter not found or not a string", nil) } // Parse the access token if provided @@ -138,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var err error tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } @@ -154,9 +155,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para bodyBytes, err := json.Marshal(payload) if err != nil { - return nil, fmt.Errorf("failed to marshal request payload: %w", err) + return nil, util.NewClientServerError("failed to marshal request payload", http.StatusInternalServerError, err) } - return source.RunQuery(ctx, tokenStr, bodyBytes) + + resp, err := source.RunQuery(ctx, tokenStr, bodyBytes) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage/cloudhealthcarefhirfetchpage.go b/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage/cloudhealthcarefhirfetchpage.go index 104bf53a73..acd55c61ca 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage/cloudhealthcarefhirfetchpage.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage/cloudhealthcarefhirfetchpage.go @@ -17,11 +17,13 @@ package fhirfetchpage import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -93,24 +95,31 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } url, ok := params.AsMap()[pageURLKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", pageURLKey), nil) } + var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } - return source.FHIRFetchPage(ctx, url, tokenStr) + + resp, err := source.FHIRFetchPage(ctx, url, tokenStr) + if err != nil { + + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything/cloudhealthcarefhirpatienteverything.go b/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything/cloudhealthcarefhirpatienteverything.go index 40c479cbfd..f81d601e03 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything/cloudhealthcarefhirpatienteverything.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything/cloudhealthcarefhirpatienteverything.go @@ -17,6 +17,7 @@ package fhirpatienteverything import ( "context" "fmt" + "net/http" "strings" "github.com/goccy/go-yaml" @@ -24,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/googleapi" ) @@ -116,26 +118,27 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) if err != nil { - return nil, err + // ValidateAndFetchStoreID usually returns input validation errors + return nil, util.NewAgentError("failed to validate store ID", err) } patientID, ok := params.AsMap()[patientIDKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", patientIDKey), nil) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } @@ -143,11 +146,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if val, ok := params.AsMap()[typeFilterKey]; ok { types, ok := val.([]any) if !ok { - return nil, fmt.Errorf("invalid '%s' parameter; expected a string array", typeFilterKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string array", typeFilterKey), nil) } typeFilterSlice, err := parameters.ConvertAnySliceToTyped(types, "string") if err != nil { - return nil, fmt.Errorf("can't convert '%s' to array of strings: %s", typeFilterKey, err) + return nil, util.NewAgentError(fmt.Sprintf("can't convert '%s' to array of strings: %s", typeFilterKey, err), err) } if len(typeFilterSlice.([]string)) != 0 { opts = append(opts, googleapi.QueryParameter("_type", strings.Join(typeFilterSlice.([]string), ","))) @@ -156,13 +159,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if since, ok := params.AsMap()[sinceFilterKey]; ok { sinceStr, ok := since.(string) if !ok { - return nil, fmt.Errorf("invalid '%s' parameter; expected a string", sinceFilterKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string", sinceFilterKey), nil) } if sinceStr != "" { opts = append(opts, googleapi.QueryParameter("_since", sinceStr)) } } - return source.FHIRPatientEverything(storeID, patientID, tokenStr, opts) + + resp, err := source.FHIRPatientEverything(storeID, patientID, tokenStr, opts) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch/cloudhealthcarefhirpatientsearch.go b/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch/cloudhealthcarefhirpatientsearch.go index 08283c8b88..5a25a5028c 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch/cloudhealthcarefhirpatientsearch.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch/cloudhealthcarefhirpatientsearch.go @@ -17,6 +17,7 @@ package fhirpatientsearch import ( "context" "fmt" + "net/http" "strings" "github.com/goccy/go-yaml" @@ -24,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/googleapi" ) @@ -150,22 +152,22 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to validate store ID", err) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } @@ -179,14 +181,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var ok bool summary, ok = v.(bool) if !ok { - return nil, fmt.Errorf("invalid '%s' parameter; expected a boolean", summaryKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a boolean", summaryKey), nil) } continue } val, ok := v.(string) if !ok { - return nil, fmt.Errorf("invalid parameter '%s'; expected a string", k) + return nil, util.NewAgentError(fmt.Sprintf("invalid parameter '%s'; expected a string", k), nil) } if val == "" { continue @@ -205,7 +207,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } parts := strings.Split(val, "/") if len(parts) != 2 { - return nil, fmt.Errorf("invalid '%s' format; expected YYYY-MM-DD/YYYY-MM-DD", k) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' format; expected YYYY-MM-DD/YYYY-MM-DD", k), nil) } var values []string if parts[0] != "" { @@ -229,13 +231,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para case familyNameKey: opts = append(opts, googleapi.QueryParameter("family", val)) default: - return nil, fmt.Errorf("unexpected parameter key %q", k) + return nil, util.NewAgentError(fmt.Sprintf("unexpected parameter key %q", k), nil) } } if summary { opts = append(opts, googleapi.QueryParameter("_summary", "text")) } - return source.FHIRPatientSearch(storeID, tokenStr, opts) + resp, err := source.FHIRPatientSearch(storeID, tokenStr, opts) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetdataset/cloudhealthcaregetdataset.go b/internal/tools/cloudhealthcare/cloudhealthcaregetdataset/cloudhealthcaregetdataset.go index 23b34a489c..6924233c74 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetdataset/cloudhealthcaregetdataset.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetdataset/cloudhealthcaregetdataset.go @@ -17,11 +17,13 @@ package gethealthcaredataset import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/healthcare/v1" ) @@ -90,19 +92,23 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } - return source.GetDataset(tokenStr) + resp, err := source.GetDataset(tokenStr) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore/cloudhealthcaregetdicomstore.go b/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore/cloudhealthcaregetdicomstore.go index f3015ea801..2ba82fa4cf 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore/cloudhealthcaregetdicomstore.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore/cloudhealthcaregetdicomstore.go @@ -17,12 +17,14 @@ package getdicomstore import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/healthcare/v1" ) @@ -107,23 +109,27 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to validate store ID", err) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } - return source.GetDICOMStore(storeID, tokenStr) + resp, err := source.GetDICOMStore(storeID, tokenStr) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics/cloudhealthcaregetdicomstoremetrics.go b/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics/cloudhealthcaregetdicomstoremetrics.go index 1a3c23b7be..40b8f3a247 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics/cloudhealthcaregetdicomstoremetrics.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics/cloudhealthcaregetdicomstoremetrics.go @@ -17,12 +17,14 @@ package getdicomstoremetrics import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/healthcare/v1" ) @@ -107,23 +109,27 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to validate store ID", err) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } - return source.GetDICOMStoreMetrics(storeID, tokenStr) + resp, err := source.GetDICOMStoreMetrics(storeID, tokenStr) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource/cloudhealthcaregetfhirresource.go b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource/cloudhealthcaregetfhirresource.go index 2d1d316489..57aa815361 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource/cloudhealthcaregetfhirresource.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource/cloudhealthcaregetfhirresource.go @@ -17,12 +17,14 @@ package getfhirresource import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -112,32 +114,36 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to validate store ID", err) } resType, ok := params.AsMap()[typeKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", typeKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", typeKey), nil) } resID, ok := params.AsMap()[idKey].(string) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", idKey), nil) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } - return source.GetFHIRResource(storeID, resType, resID, tokenStr) + resp, err := source.GetFHIRResource(storeID, resType, resID, tokenStr) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore/cloudhealthcaregetfhirstore.go b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore/cloudhealthcaregetfhirstore.go index 633df3b9dc..e4ec7043eb 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore/cloudhealthcaregetfhirstore.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore/cloudhealthcaregetfhirstore.go @@ -17,12 +17,14 @@ package getfhirstore import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/healthcare/v1" ) @@ -107,23 +109,27 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to validate store ID", err) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } - return source.GetFHIRStore(storeID, tokenStr) + resp, err := source.GetFHIRStore(storeID, tokenStr) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics/cloudhealthcaregetfhirstoremetrics.go b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics/cloudhealthcaregetfhirstoremetrics.go index 39088122ba..d3e4eb07fb 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics/cloudhealthcaregetfhirstoremetrics.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics/cloudhealthcaregetfhirstoremetrics.go @@ -17,12 +17,14 @@ package getfhirstoremetrics import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/healthcare/v1" ) @@ -107,23 +109,27 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to validate store ID", err) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } - return source.GetFHIRStoreMetrics(storeID, tokenStr) + resp, err := source.GetFHIRStoreMetrics(storeID, tokenStr) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores/cloudhealthcarelistdicomstores.go b/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores/cloudhealthcarelistdicomstores.go index 612a455b39..fb43e9d353 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores/cloudhealthcarelistdicomstores.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores/cloudhealthcarelistdicomstores.go @@ -17,11 +17,13 @@ package listdicomstores import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/healthcare/v1" ) @@ -90,19 +92,23 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } - return source.ListDICOMStores(tokenStr) + resp, err := source.ListDICOMStores(tokenStr) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores/cloudhealthcarelistfhirstores.go b/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores/cloudhealthcarelistfhirstores.go index bb1e182416..203c666b12 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores/cloudhealthcarelistfhirstores.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores/cloudhealthcarelistfhirstores.go @@ -17,11 +17,13 @@ package listfhirstores import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/healthcare/v1" ) @@ -90,19 +92,23 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } - return source.ListFHIRStores(tokenStr) + resp, err := source.ListFHIRStores(tokenStr) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance/cloudhealthcareretrieverendereddicominstance.go b/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance/cloudhealthcareretrieverendereddicominstance.go index c3379142ce..711a0cfc86 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance/cloudhealthcareretrieverendereddicominstance.go +++ b/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance/cloudhealthcareretrieverendereddicominstance.go @@ -17,12 +17,14 @@ package retrieverendereddicominstance import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -117,40 +119,44 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to validate store ID", err) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } study, ok := params.AsMap()[studyInstanceUIDKey].(string) if !ok { - return nil, fmt.Errorf("invalid '%s' parameter; expected a string", studyInstanceUIDKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string", studyInstanceUIDKey), nil) } series, ok := params.AsMap()[seriesInstanceUIDKey].(string) if !ok { - return nil, fmt.Errorf("invalid '%s' parameter; expected a string", seriesInstanceUIDKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string", seriesInstanceUIDKey), nil) } sop, ok := params.AsMap()[sopInstanceUIDKey].(string) if !ok { - return nil, fmt.Errorf("invalid '%s' parameter; expected a string", sopInstanceUIDKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string", sopInstanceUIDKey), nil) } frame, ok := params.AsMap()[frameNumberKey].(int) if !ok { - return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected an integer", frameNumberKey), nil) } - return source.RetrieveRenderedDICOMInstance(storeID, study, series, sop, frame, tokenStr) + resp, err := source.RetrieveRenderedDICOMInstance(storeID, study, series, sop, frame, tokenStr) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances/cloudhealthcaresearchdicominstances.go b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances/cloudhealthcaresearchdicominstances.go index 1de1f0b12f..a3183238e8 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances/cloudhealthcaresearchdicominstances.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances/cloudhealthcaresearchdicominstances.go @@ -17,6 +17,7 @@ package searchdicominstances import ( "context" "fmt" + "net/http" "strings" "github.com/goccy/go-yaml" @@ -24,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/googleapi" ) @@ -131,33 +133,33 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to validate store ID", err) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } opts, err := common.ParseDICOMSearchParameters(params, []string{sopInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey}) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to parse DICOM search parameters", err) } paramsMap := params.AsMap() dicomWebPath := "instances" if studyInstanceUID, ok := paramsMap[studyInstanceUIDKey]; ok { id, ok := studyInstanceUID.(string) if !ok { - return nil, fmt.Errorf("invalid '%s' parameter; expected a string", studyInstanceUIDKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string", studyInstanceUIDKey), nil) } if id != "" { dicomWebPath = fmt.Sprintf("studies/%s/instances", id) @@ -166,7 +168,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if seriesInstanceUID, ok := paramsMap[seriesInstanceUIDKey]; ok { id, ok := seriesInstanceUID.(string) if !ok { - return nil, fmt.Errorf("invalid '%s' parameter; expected a string", seriesInstanceUIDKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string", seriesInstanceUIDKey), nil) } if id != "" { if dicomWebPath != "instances" { @@ -176,7 +178,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } } - return source.SearchDICOM(t.Type, storeID, dicomWebPath, tokenStr, opts) + resp, err := source.SearchDICOM(t.Type, storeID, dicomWebPath, tokenStr, opts) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries/cloudhealthcaresearchdicomseries.go b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries/cloudhealthcaresearchdicomseries.go index dac124e1ee..75735b5db5 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries/cloudhealthcaresearchdicomseries.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries/cloudhealthcaresearchdicomseries.go @@ -17,12 +17,14 @@ package searchdicomseries import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/googleapi" ) @@ -128,40 +130,44 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to validate store ID", err) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } opts, err := common.ParseDICOMSearchParameters(params, []string{seriesInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey}) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to parse DICOM search parameters", err) } paramsMap := params.AsMap() dicomWebPath := "series" if studyInstanceUID, ok := paramsMap[studyInstanceUIDKey]; ok { id, ok := studyInstanceUID.(string) if !ok { - return nil, fmt.Errorf("invalid '%s' parameter; expected a string", studyInstanceUIDKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string", studyInstanceUIDKey), nil) } if id != "" { dicomWebPath = fmt.Sprintf("studies/%s/series", id) } } - return source.SearchDICOM(t.Type, storeID, dicomWebPath, tokenStr, opts) + resp, err := source.SearchDICOM(t.Type, storeID, dicomWebPath, tokenStr, opts) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies/cloudhealthcaresearchdicomstudies.go b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies/cloudhealthcaresearchdicomstudies.go index 7d51b22d83..d1f2a2ed30 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies/cloudhealthcaresearchdicomstudies.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies/cloudhealthcaresearchdicomstudies.go @@ -17,12 +17,14 @@ package searchdicomstudies import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/googleapi" ) @@ -124,28 +126,32 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to validate store ID", err) } var tokenStr string if source.UseClientAuthorization() { tokenStr, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) + return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err) } } opts, err := common.ParseDICOMSearchParameters(params, []string{studyInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey}) if err != nil { - return nil, err + return nil, util.NewAgentError("failed to parse DICOM search parameters", err) } dicomWebPath := "studies" - return source.SearchDICOM(t.Type, storeID, dicomWebPath, tokenStr, opts) + resp, err := source.SearchDICOM(t.Type, storeID, dicomWebPath, tokenStr, opts) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudloggingadmin/cloudloggingadminlistlognames/cloudloggingadminlistlognames.go b/internal/tools/cloudloggingadmin/cloudloggingadminlistlognames/cloudloggingadminlistlognames.go index 063fbba334..73253b7a9a 100644 --- a/internal/tools/cloudloggingadmin/cloudloggingadminlistlognames/cloudloggingadminlistlognames.go +++ b/internal/tools/cloudloggingadmin/cloudloggingadminlistlognames/cloudloggingadminlistlognames.go @@ -16,11 +16,13 @@ package cloudloggingadminlistlognames import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -89,10 +91,10 @@ type Tool struct { Parameters parameters.Parameters `yaml:"parameters"` } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } limit := defaultLimit @@ -100,18 +102,22 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if val, ok := paramsMap["limit"].(int); ok && val > 0 { limit = val } else if ok && val < 0 { - return nil, fmt.Errorf("limit must be greater than or equal to 1") + return nil, util.NewAgentError("limit must be greater than or equal to 1", nil) } tokenString := "" if source.UseClientAuthorization() { tokenString, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("failed to parse access token: %w", err) + return nil, util.NewClientServerError("failed to parse access token", http.StatusUnauthorized, err) } } - return source.ListLogNames(ctx, limit, tokenString) + resp, err := source.ListLogNames(ctx, limit, tokenString) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudloggingadmin/cloudloggingadminlistresourcetypes/cloudloggingadminlistresourcetypes.go b/internal/tools/cloudloggingadmin/cloudloggingadminlistresourcetypes/cloudloggingadminlistresourcetypes.go index 1326bf037c..ce171ec8aa 100644 --- a/internal/tools/cloudloggingadmin/cloudloggingadminlistresourcetypes/cloudloggingadminlistresourcetypes.go +++ b/internal/tools/cloudloggingadmin/cloudloggingadminlistresourcetypes/cloudloggingadminlistresourcetypes.go @@ -16,11 +16,13 @@ package cloudloggingadminlistresourcetypes import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -84,21 +86,25 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } tokenString := "" if source.UseClientAuthorization() { tokenString, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("failed to parse access token: %w", err) + return nil, util.NewClientServerError("failed to parse access token", http.StatusUnauthorized, err) } } - return source.ListResourceTypes(ctx, tokenString) + resp, err := source.ListResourceTypes(ctx, tokenString) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudloggingadmin/cloudloggingadminquerylogs/cloudloggingadminquerylogs.go b/internal/tools/cloudloggingadmin/cloudloggingadminquerylogs/cloudloggingadminquerylogs.go index ab62ef3510..b5216fac02 100644 --- a/internal/tools/cloudloggingadmin/cloudloggingadminquerylogs/cloudloggingadminquerylogs.go +++ b/internal/tools/cloudloggingadmin/cloudloggingadminquerylogs/cloudloggingadminquerylogs.go @@ -16,6 +16,7 @@ package cloudloggingadminquerylogs import ( "context" "fmt" + "net/http" "time" "github.com/goccy/go-yaml" @@ -23,6 +24,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" cla "github.com/googleapis/genai-toolbox/internal/sources/cloudloggingadmin" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -104,10 +106,10 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } // Parse parameters @@ -119,7 +121,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if val, ok := paramsMap["limit"].(int); ok && val > 0 { limit = val } else if ok && val < 0 { - return nil, fmt.Errorf("limit must be greater than or equal to 1") + return nil, util.NewAgentError("limit must be greater than or equal to 1", nil) } // Check for verbosity of output @@ -129,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var filter string if f, ok := paramsMap["filter"].(string); ok { if len(f) == 0 { - return nil, fmt.Errorf("filter cannot be empty if provided") + return nil, util.NewAgentError("filter cannot be empty if provided", nil) } filter = f } @@ -138,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var startTime string if val, ok := paramsMap["startTime"].(string); ok && val != "" { if _, err := time.Parse(time.RFC3339, val); err != nil { - return nil, fmt.Errorf("startTime must be in RFC3339 format (e.g., 2025-12-09T00:00:00Z): %w", err) + return nil, util.NewAgentError(fmt.Sprintf("startTime must be in RFC3339 format (e.g., 2025-12-09T00:00:00Z): %v", err), err) } startTime = val } else { @@ -149,7 +151,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var endTime string if val, ok := paramsMap["endTime"].(string); ok && val != "" { if _, err := time.Parse(time.RFC3339, val); err != nil { - return nil, fmt.Errorf("endTime must be in RFC3339 format (e.g., 2025-12-09T23:59:59Z): %w", err) + return nil, util.NewAgentError(fmt.Sprintf("endTime must be in RFC3339 format (e.g., 2025-12-09T23:59:59Z): %v", err), err) } endTime = val } @@ -158,7 +160,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if source.UseClientAuthorization() { tokenString, err = accessToken.ParseBearerToken() if err != nil { - return nil, fmt.Errorf("failed to parse access token: %w", err) + return nil, util.NewClientServerError("failed to parse access token", http.StatusUnauthorized, err) } } @@ -171,7 +173,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Limit: limit, } - return source.QueryLogs(ctx, queryParams, tokenString) + resp, err := source.QueryLogs(ctx, queryParams, tokenString) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudmonitoring/cloudmonitoring.go b/internal/tools/cloudmonitoring/cloudmonitoring.go index 3d28b61f68..b3524b58bd 100644 --- a/internal/tools/cloudmonitoring/cloudmonitoring.go +++ b/internal/tools/cloudmonitoring/cloudmonitoring.go @@ -23,6 +23,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -93,22 +94,26 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() projectID, ok := paramsMap["projectId"].(string) if !ok { - return nil, fmt.Errorf("projectId parameter not found or not a string") + return nil, util.NewAgentError("projectId parameter not found or not a string", nil) } query, ok := paramsMap["query"].(string) if !ok { - return nil, fmt.Errorf("query parameter not found or not a string") + return nil, util.NewAgentError("query parameter not found or not a string", nil) } - return source.RunQuery(projectID, query) + resp, err := source.RunQuery(projectID, query) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go b/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go index 03e5a75390..786fa45ced 100644 --- a/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go +++ b/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go @@ -17,11 +17,13 @@ package cloudsqlcloneinstance import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" ) @@ -124,31 +126,35 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok { - return nil, fmt.Errorf("error casting 'project' parameter: %v", paramsMap["project"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'project' parameter: %v", paramsMap["project"]), nil) } sourceInstanceName, ok := paramsMap["sourceInstanceName"].(string) if !ok { - return nil, fmt.Errorf("error casting 'sourceInstanceName' parameter: %v", paramsMap["sourceInstanceName"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'sourceInstanceName' parameter: %v", paramsMap["sourceInstanceName"]), nil) } destinationInstanceName, ok := paramsMap["destinationInstanceName"].(string) if !ok { - return nil, fmt.Errorf("error casting 'destinationInstanceName' parameter: %v", paramsMap["destinationInstanceName"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'destinationInstanceName' parameter: %v", paramsMap["destinationInstanceName"]), nil) } pointInTime, _ := paramsMap["pointInTime"].(string) preferredZone, _ := paramsMap["preferredZone"].(string) preferredSecondaryZone, _ := paramsMap["preferredSecondaryZone"].(string) - return source.CloneInstance(ctx, project, sourceInstanceName, destinationInstanceName, pointInTime, preferredZone, preferredSecondaryZone, string(accessToken)) + resp, err := source.CloneInstance(ctx, project, sourceInstanceName, destinationInstanceName, pointInTime, preferredZone, preferredSecondaryZone, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsql/cloudsqlcreatebackup/cloudsqlcreatebackup.go b/internal/tools/cloudsql/cloudsqlcreatebackup/cloudsqlcreatebackup.go index e5b5b6c3b9..926efeee1d 100644 --- a/internal/tools/cloudsql/cloudsqlcreatebackup/cloudsqlcreatebackup.go +++ b/internal/tools/cloudsql/cloudsqlcreatebackup/cloudsqlcreatebackup.go @@ -17,11 +17,13 @@ package cloudsqlcreatebackup import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/sqladmin/v1" ) @@ -120,26 +122,30 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok { - return nil, fmt.Errorf("error casting 'project' parameter: %v", paramsMap["project"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'project' parameter: %v", paramsMap["project"]), nil) } instance, ok := paramsMap["instance"].(string) if !ok { - return nil, fmt.Errorf("error casting 'instance' parameter: %v", paramsMap["instance"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'instance' parameter: %v", paramsMap["instance"]), nil) } location, _ := paramsMap["location"].(string) description, _ := paramsMap["backup_description"].(string) - return source.InsertBackupRun(ctx, project, instance, location, description, string(accessToken)) + resp, err := source.InsertBackupRun(ctx, project, instance, location, description, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go b/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go index 3b1573c70c..422e60bf3c 100644 --- a/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go +++ b/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go @@ -17,11 +17,13 @@ package cloudsqlcreatedatabase import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -117,27 +119,31 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok { - return nil, fmt.Errorf("missing 'project' parameter") + return nil, util.NewAgentError("missing 'project' parameter", nil) } instance, ok := paramsMap["instance"].(string) if !ok { - return nil, fmt.Errorf("missing 'instance' parameter") + return nil, util.NewAgentError("missing 'instance' parameter", nil) } name, ok := paramsMap["name"].(string) if !ok { - return nil, fmt.Errorf("missing 'name' parameter") + return nil, util.NewAgentError("missing 'name' parameter", nil) } - return source.CreateDatabase(ctx, name, project, instance, string(accessToken)) + resp, err := source.CreateDatabase(ctx, name, project, instance, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go b/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go index 1594b81dd3..101ea45f96 100644 --- a/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go +++ b/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go @@ -17,11 +17,13 @@ package cloudsqlcreateusers import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -119,30 +121,38 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok { - return nil, fmt.Errorf("missing 'project' parameter") + return nil, util.NewAgentError("missing 'project' parameter", nil) } instance, ok := paramsMap["instance"].(string) if !ok { - return nil, fmt.Errorf("missing 'instance' parameter") + return nil, util.NewAgentError("missing 'instance' parameter", nil) } name, ok := paramsMap["name"].(string) if !ok { - return nil, fmt.Errorf("missing 'name' parameter") + return nil, util.NewAgentError("missing 'name' parameter", nil) } iamUser, _ := paramsMap["iamUser"].(bool) password, _ := paramsMap["password"].(string) - return source.CreateUsers(ctx, project, instance, name, password, iamUser, string(accessToken)) + if !iamUser && password == "" { + return nil, util.NewAgentError("missing 'password' parameter for non-IAM user", nil) + } + + resp, err := source.CreateUsers(ctx, project, instance, name, password, iamUser, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go b/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go index d65aa749be..8602ab2740 100644 --- a/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go +++ b/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go @@ -17,11 +17,13 @@ package cloudsqlgetinstances import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -117,23 +119,27 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() projectId, ok := paramsMap["projectId"].(string) if !ok { - return nil, fmt.Errorf("missing 'projectId' parameter") + return nil, util.NewAgentError("missing 'projectId' parameter", nil) } instanceId, ok := paramsMap["instanceId"].(string) if !ok { - return nil, fmt.Errorf("missing 'instanceId' parameter") + return nil, util.NewAgentError("missing 'instanceId' parameter", nil) } - return source.GetInstance(ctx, projectId, instanceId, string(accessToken)) + resp, err := source.GetInstance(ctx, projectId, instanceId, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go b/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go index f185862622..41ebf08fe2 100644 --- a/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go +++ b/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go @@ -17,11 +17,13 @@ package cloudsqllistdatabases import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -116,23 +118,27 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok { - return nil, fmt.Errorf("missing 'project' parameter") + return nil, util.NewAgentError("missing 'project' parameter", nil) } instance, ok := paramsMap["instance"].(string) if !ok { - return nil, fmt.Errorf("missing 'instance' parameter") + return nil, util.NewAgentError("missing 'instance' parameter", nil) } - return source.ListDatabase(ctx, project, instance, string(accessToken)) + resp, err := source.ListDatabase(ctx, project, instance, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go b/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go index 8a032b73e9..9c869eaae6 100644 --- a/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go +++ b/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go @@ -17,11 +17,13 @@ package cloudsqllistinstances import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -115,19 +117,23 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok { - return nil, fmt.Errorf("missing 'project' parameter") + return nil, util.NewAgentError("missing 'project' parameter", nil) } - return source.ListInstance(ctx, project, string(accessToken)) + resp, err := source.ListInstance(ctx, project, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsql/cloudsqlrestorebackup/cloudsqlrestorebackup.go b/internal/tools/cloudsql/cloudsqlrestorebackup/cloudsqlrestorebackup.go index 84ae63b3f9..a4e909d157 100644 --- a/internal/tools/cloudsql/cloudsqlrestorebackup/cloudsqlrestorebackup.go +++ b/internal/tools/cloudsql/cloudsqlrestorebackup/cloudsqlrestorebackup.go @@ -17,11 +17,13 @@ package cloudsqlrestorebackup import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/sqladmin/v1" ) @@ -120,29 +122,33 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() targetProject, ok := paramsMap["target_project"].(string) if !ok { - return nil, fmt.Errorf("error casting 'target_project' parameter: %v", paramsMap["target_project"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'target_project' parameter: %v", paramsMap["target_project"]), nil) } targetInstance, ok := paramsMap["target_instance"].(string) if !ok { - return nil, fmt.Errorf("error casting 'target_instance' parameter: %v", paramsMap["target_instance"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'target_instance' parameter: %v", paramsMap["target_instance"]), nil) } backupID, ok := paramsMap["backup_id"].(string) if !ok { - return nil, fmt.Errorf("error casting 'backup_id' parameter: %v", paramsMap["backup_id"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'backup_id' parameter: %v", paramsMap["backup_id"]), nil) } sourceProject, _ := paramsMap["source_project"].(string) sourceInstance, _ := paramsMap["source_instance"].(string) - return source.RestoreBackup(ctx, targetProject, targetInstance, sourceProject, sourceInstance, backupID, string(accessToken)) + resp, err := source.RestoreBackup(ctx, targetProject, targetInstance, sourceProject, sourceInstance, backupID, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go b/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go index e6d40885bf..610330ad65 100644 --- a/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go +++ b/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go @@ -17,12 +17,14 @@ package cloudsqlwaitforoperation import ( "context" "fmt" + "net/http" "time" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/sqladmin/v1" ) @@ -210,21 +212,21 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok { - return nil, fmt.Errorf("missing 'project' parameter") + return nil, util.NewAgentError("missing 'project' parameter", nil) } operationID, ok := paramsMap["operation"].(string) if !ok { - return nil, fmt.Errorf("missing 'operation' parameter") + return nil, util.NewAgentError("missing 'operation' parameter", nil) } ctx, cancel := context.WithTimeout(ctx, 30*time.Minute) @@ -232,7 +234,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para service, err := source.GetService(ctx, string(accessToken)) if err != nil { - return nil, err + return nil, util.ProcessGcpError(err) } delay := t.Delay @@ -244,13 +246,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para for retries < maxRetries { select { case <-ctx.Done(): - return nil, fmt.Errorf("timed out waiting for operation: %w", ctx.Err()) + return nil, util.NewClientServerError("timed out waiting for operation", http.StatusRequestTimeout, ctx.Err()) default: } op, err := source.GetWaitForOperations(ctx, service, project, operationID, cloudSQLConnectionMessageTemplate, delay) if err != nil { - return nil, err + return nil, util.ProcessGcpError(err) } else if op != nil { return op, nil } @@ -262,7 +264,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } retries++ } - return nil, fmt.Errorf("exceeded max retries waiting for operation") + return nil, util.NewClientServerError("exceeded max retries waiting for operation", http.StatusGatewayTimeout, fmt.Errorf("exceeded max retries")) } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go b/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go index 23d4d2d2e4..7adbd4dc2c 100644 --- a/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go +++ b/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go @@ -17,12 +17,14 @@ package cloudsqlmssqlcreateinstance import ( "context" "fmt" + "net/http" "strings" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/sqladmin/v1" ) @@ -121,33 +123,33 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok { - return nil, fmt.Errorf("error casting 'project' parameter: %s", paramsMap["project"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'project' parameter: %s", paramsMap["project"]), nil) } name, ok := paramsMap["name"].(string) if !ok { - return nil, fmt.Errorf("error casting 'name' parameter: %s", paramsMap["name"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'name' parameter: %s", paramsMap["name"]), nil) } dbVersion, ok := paramsMap["databaseVersion"].(string) if !ok { - return nil, fmt.Errorf("error casting 'databaseVersion' parameter: %s", paramsMap["databaseVersion"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'databaseVersion' parameter: %s", paramsMap["databaseVersion"]), nil) } rootPassword, ok := paramsMap["rootPassword"].(string) if !ok { - return nil, fmt.Errorf("error casting 'rootPassword' parameter: %s", paramsMap["rootPassword"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'rootPassword' parameter: %s", paramsMap["rootPassword"]), nil) } editionPreset, ok := paramsMap["editionPreset"].(string) if !ok { - return nil, fmt.Errorf("error casting 'editionPreset' parameter: %s", paramsMap["editionPreset"]) + return nil, util.NewAgentError(fmt.Sprintf("error casting 'editionPreset' parameter: %s", paramsMap["editionPreset"]), nil) } settings := sqladmin.Settings{} switch strings.ToLower(editionPreset) { @@ -164,9 +166,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para settings.DataDiskSizeGb = 100 settings.DataDiskType = "PD_SSD" default: - return nil, fmt.Errorf("invalid 'editionPreset': %q. Must be either 'Production' or 'Development'", editionPreset) + return nil, util.NewAgentError(fmt.Sprintf("invalid 'editionPreset': %q. Must be either 'Production' or 'Development'", editionPreset), nil) } - return source.CreateInstance(ctx, project, name, dbVersion, rootPassword, settings, string(accessToken)) + resp, err := source.CreateInstance(ctx, project, name, dbVersion, rootPassword, settings, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go b/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go index ab78fdc6b7..358b1b343d 100644 --- a/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go +++ b/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go @@ -17,12 +17,14 @@ package cloudsqlmysqlcreateinstance import ( "context" "fmt" + "net/http" "strings" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" ) @@ -121,33 +123,33 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok { - return nil, fmt.Errorf("missing 'project' parameter") + return nil, util.NewAgentError("missing 'project' parameter", nil) } name, ok := paramsMap["name"].(string) if !ok { - return nil, fmt.Errorf("missing 'name' parameter") + return nil, util.NewAgentError("missing 'name' parameter", nil) } dbVersion, ok := paramsMap["databaseVersion"].(string) if !ok { - return nil, fmt.Errorf("missing 'databaseVersion' parameter") + return nil, util.NewAgentError("missing 'databaseVersion' parameter", nil) } rootPassword, ok := paramsMap["rootPassword"].(string) if !ok { - return nil, fmt.Errorf("missing 'rootPassword' parameter") + return nil, util.NewAgentError("missing 'rootPassword' parameter", nil) } editionPreset, ok := paramsMap["editionPreset"].(string) if !ok { - return nil, fmt.Errorf("missing 'editionPreset' parameter") + return nil, util.NewAgentError("missing 'editionPreset' parameter", nil) } settings := sqladmin.Settings{} @@ -165,10 +167,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para settings.DataDiskSizeGb = 100 settings.DataDiskType = "PD_SSD" default: - return nil, fmt.Errorf("invalid 'editionPreset': %q. Must be either 'Production' or 'Development'", editionPreset) + return nil, util.NewAgentError(fmt.Sprintf("invalid 'editionPreset': %q. Must be either 'Production' or 'Development'", editionPreset), nil) } - return source.CreateInstance(ctx, project, name, dbVersion, rootPassword, settings, string(accessToken)) + resp, err := source.CreateInstance(ctx, project, name, dbVersion, rootPassword, settings, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go b/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go index 93639c84d5..e0e0a9f3f8 100644 --- a/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go +++ b/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go @@ -17,12 +17,14 @@ package cloudsqlpgcreateinstances import ( "context" "fmt" + "net/http" "strings" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" ) @@ -121,33 +123,33 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok { - return nil, fmt.Errorf("missing 'project' parameter") + return nil, util.NewAgentError("missing 'project' parameter", nil) } name, ok := paramsMap["name"].(string) if !ok { - return nil, fmt.Errorf("missing 'name' parameter") + return nil, util.NewAgentError("missing 'name' parameter", nil) } dbVersion, ok := paramsMap["databaseVersion"].(string) if !ok { - return nil, fmt.Errorf("missing 'databaseVersion' parameter") + return nil, util.NewAgentError("missing 'databaseVersion' parameter", nil) } rootPassword, ok := paramsMap["rootPassword"].(string) if !ok { - return nil, fmt.Errorf("missing 'rootPassword' parameter") + return nil, util.NewAgentError("missing 'rootPassword' parameter", nil) } editionPreset, ok := paramsMap["editionPreset"].(string) if !ok { - return nil, fmt.Errorf("missing 'editionPreset' parameter") + return nil, util.NewAgentError("missing 'editionPreset' parameter", nil) } settings := sqladmin.Settings{} @@ -165,9 +167,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para settings.DataDiskSizeGb = 100 settings.DataDiskType = "PD_SSD" default: - return nil, fmt.Errorf("invalid 'editionPreset': %q. Must be either 'Production' or 'Development'", editionPreset) + return nil, util.NewAgentError(fmt.Sprintf("invalid 'editionPreset': %q. Must be either 'Production' or 'Development'", editionPreset), nil) } - return source.CreateInstance(ctx, project, name, dbVersion, rootPassword, settings, string(accessToken)) + resp, err := source.CreateInstance(ctx, project, name, dbVersion, rootPassword, settings, string(accessToken)) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck/cloudsqlpgupgradeprecheck.go b/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck/cloudsqlpgupgradeprecheck.go index f5d57750e6..4f00896fb7 100644 --- a/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck/cloudsqlpgupgradeprecheck.go +++ b/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck/cloudsqlpgupgradeprecheck.go @@ -17,12 +17,14 @@ package cloudsqlpgupgradeprecheck import ( "context" "fmt" + "net/http" "time" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" ) @@ -132,31 +134,31 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok || project == "" { - return nil, fmt.Errorf("missing or empty 'project' parameter") + return nil, util.NewAgentError("missing or empty 'project' parameter", nil) } instanceName, ok := paramsMap["instance"].(string) if !ok || instanceName == "" { - return nil, fmt.Errorf("missing or empty 'instance' parameter") + return nil, util.NewAgentError("missing or empty 'instance' parameter", nil) } targetVersion, ok := paramsMap["targetDatabaseVersion"].(string) if !ok || targetVersion == "" { // This should not happen due to the default value - return nil, fmt.Errorf("missing or empty 'targetDatabaseVersion' parameter") + return nil, util.NewAgentError("missing or empty 'targetDatabaseVersion' parameter", nil) } service, err := source.GetService(ctx, string(accessToken)) if err != nil { - return nil, fmt.Errorf("failed to get HTTP client from source: %w", err) + return nil, util.ProcessGcpError(err) } reqBody := &sqladmin.InstancesPreCheckMajorVersionUpgradeRequest{ @@ -168,7 +170,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para call := service.Instances.PreCheckMajorVersionUpgrade(project, instanceName, reqBody).Context(ctx) op, err := call.Do() if err != nil { - return nil, fmt.Errorf("failed to start pre-check operation: %w", err) + return nil, util.ProcessGcpError(err) } const pollTimeout = 20 * time.Second @@ -177,7 +179,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para for time.Now().Before(cutoffTime) { currentOp, err := service.Operations.Get(project, op.Name).Context(ctx).Do() if err != nil { - return nil, fmt.Errorf("failed to get operation status: %w", err) + return nil, util.ProcessGcpError(err) } if currentOp.Status == "DONE" { @@ -186,7 +188,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if currentOp.Error.Errors[0].Code != "" { errMsg = fmt.Sprintf("%s (Code: %s)", errMsg, currentOp.Error.Errors[0].Code) } - return nil, fmt.Errorf("%s", errMsg) + return nil, util.NewClientServerError(errMsg, http.StatusInternalServerError, fmt.Errorf("pre-check operation failed with error: %s", errMsg)) } var preCheckItems []*sqladmin.PreCheckResponse @@ -199,7 +201,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, util.NewClientServerError("timed out waiting for operation", http.StatusRequestTimeout, ctx.Err()) case <-time.After(5 * time.Second): } } diff --git a/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go b/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go index 7bd4f07345..efc7c0962e 100644 --- a/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go +++ b/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go @@ -17,6 +17,7 @@ package cockroachdbexecutesql import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -104,26 +105,27 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("parameter 'sql' is required, unable to cast %v", paramsMap["sql"]), nil) } + logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", t.Type, sql)) results, err := source.Query(ctx, sql) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, util.ProcessGeneralError(fmt.Errorf("unable to execute query: %w", err)) } defer results.Close() @@ -133,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para for results.Next() { v, err := results.Values() if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) + return nil, util.NewClientServerError("unable to parse row", http.StatusInternalServerError, err) } row := orderedmap.Row{} for i, f := range fields { @@ -143,16 +145,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, util.ProcessGeneralError(fmt.Errorf("error during row iteration: %w", err)) } return out, nil } -func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { - return parameters.ParseParams(t.Parameters, data, claims) -} - func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { return params, nil } diff --git a/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go b/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go index 2a5c2dbc8e..0f834ec416 100644 --- a/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go +++ b/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go @@ -17,12 +17,14 @@ package cockroachdblistschemas import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5" ) @@ -116,15 +118,15 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } results, err := source.Query(ctx, listSchemasStatement) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, util.ProcessGeneralError(fmt.Errorf("unable to execute query: %w", err)) } defer results.Close() @@ -134,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para for results.Next() { values, err := results.Values() if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) + return nil, util.NewClientServerError("unable to parse row", http.StatusInternalServerError, err) } rowMap := make(map[string]any) for i, field := range fields { @@ -144,16 +146,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if err := results.Err(); err != nil { - return nil, fmt.Errorf("error reading query results: %w", err) + return nil, util.ProcessGeneralError(fmt.Errorf("error reading query results: %w", err)) } return out, nil } -func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { - return parameters.ParseParams(t.AllParams, data, claims) -} - func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { return params, nil } diff --git a/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go b/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go index 254ee3b658..d99e0297d9 100644 --- a/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go +++ b/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go @@ -17,12 +17,14 @@ package cockroachdblisttables import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5" ) @@ -179,26 +181,26 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() tableNames, ok := paramsMap["table_names"].(string) if !ok { - return nil, fmt.Errorf("invalid 'table_names' parameter; expected a string") + return nil, util.NewAgentError("invalid 'table_names' parameter; expected a string", nil) } outputFormat, _ := paramsMap["output_format"].(string) if outputFormat != "simple" && outputFormat != "detailed" { - return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) + return nil, util.NewAgentError(fmt.Sprintf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat), nil) } results, err := source.Query(ctx, listTablesStatement, tableNames, outputFormat) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, util.ProcessGeneralError(fmt.Errorf("unable to execute query: %w", err)) } defer results.Close() @@ -208,7 +210,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para for results.Next() { values, err := results.Values() if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) + return nil, util.NewClientServerError("unable to parse row", http.StatusInternalServerError, err) } rowMap := make(map[string]any) for i, field := range fields { @@ -218,16 +220,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if err := results.Err(); err != nil { - return nil, fmt.Errorf("error reading query results: %w", err) + return nil, util.ProcessGeneralError(fmt.Errorf("error reading query results: %w", err)) } return out, nil } -func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { - return parameters.ParseParams(t.AllParams, data, claims) -} - func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { return params, nil } diff --git a/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go b/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go index 33b1830545..7dbf0017a7 100644 --- a/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go +++ b/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go @@ -17,12 +17,14 @@ package cockroachdbsql import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5" @@ -110,26 +112,26 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError(fmt.Sprintf("unable to resolve template params: %v", err), err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError(fmt.Sprintf("unable to extract standard params: %v", err), err) } sliceParams := newParams.AsSlice() results, err := source.Query(ctx, newStatement, sliceParams...) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, util.ProcessGeneralError(fmt.Errorf("unable to execute query: %w", err)) } defer results.Close() @@ -139,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para for results.Next() { v, err := results.Values() if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) + return nil, util.NewClientServerError("unable to parse row", http.StatusInternalServerError, err) } row := orderedmap.Row{} for i, f := range fields { @@ -149,16 +151,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, util.ProcessGeneralError(fmt.Errorf("unable to execute query: %w", err)) } return out, nil } -func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { - return parameters.ParseParams(t.AllParams, data, claims) -} - func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { return params, nil } diff --git a/internal/tools/couchbase/couchbase.go b/internal/tools/couchbase/couchbase.go index b15515d623..439d7cb053 100644 --- a/internal/tools/couchbase/couchbase.go +++ b/internal/tools/couchbase/couchbase.go @@ -17,12 +17,14 @@ package couchbase import ( "context" "fmt" + "net/http" "github.com/couchbase/gocb/v2" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -58,7 +60,6 @@ type Config struct { TemplateParameters parameters.Parameters `yaml:"templateParameters"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -72,7 +73,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup t := Tool{ Config: cfg, AllParams: allParameters, @@ -82,7 +82,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -96,23 +95,28 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } namedParamsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, namedParamsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, namedParamsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } - return source.RunSQL(newStatement, newParams) + + resp, err := source.RunSQL(newStatement, newParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/dataform/dataformcompilelocal/dataformcompilelocal.go b/internal/tools/dataform/dataformcompilelocal/dataformcompilelocal.go index 61b77e79cf..6f8bc383d1 100644 --- a/internal/tools/dataform/dataformcompilelocal/dataformcompilelocal.go +++ b/internal/tools/dataform/dataformcompilelocal/dataformcompilelocal.go @@ -24,6 +24,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -86,18 +87,19 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { paramsMap := params.AsMap() projectDir, ok := paramsMap["project_dir"].(string) if !ok || projectDir == "" { - return nil, fmt.Errorf("error casting 'project_dir' to string or invalid value") + return nil, util.NewAgentError("error casting 'project_dir' to string or invalid value", nil) } cmd := exec.CommandContext(ctx, "dataform", "compile", projectDir, "--json") output, err := cmd.CombinedOutput() if err != nil { - return nil, fmt.Errorf("error executing dataform compile: %w\nOutput: %s", err, string(output)) + // Compilation failures are considered AgentErrors (invalid user code/project) + return nil, util.NewAgentError(fmt.Sprintf("error executing dataform compile: %v\nOutput: %s", err, string(output)), err) } return strings.TrimSpace(string(output)), nil diff --git a/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go b/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go index fdfe656eb8..3cbf3fa7ea 100644 --- a/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go +++ b/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go @@ -17,12 +17,14 @@ package dataplexlookupentry import ( "context" "fmt" + "net/http" dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -110,10 +112,10 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() @@ -122,10 +124,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para view, _ := paramsMap["view"].(int) aspectTypeSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["aspectTypes"].([]any), "string") if err != nil { - return nil, fmt.Errorf("can't convert aspectTypes to array of strings: %s", err) + return nil, util.NewAgentError(fmt.Sprintf("can't convert aspectTypes to array of strings: %s", err), err) } aspectTypes := aspectTypeSlice.([]string) - return source.LookupEntry(ctx, name, view, aspectTypes, entry) + resp, err := source.LookupEntry(ctx, name, view, aspectTypes, entry) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go b/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go index b57b598fca..7489d1a1cc 100644 --- a/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go +++ b/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go @@ -17,12 +17,14 @@ package dataplexsearchaspecttypes import ( "context" "fmt" + "net/http" "cloud.google.com/go/dataplex/apiv1/dataplexpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -93,16 +95,29 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - query, _ := paramsMap["query"].(string) - pageSize, _ := paramsMap["pageSize"].(int) - orderBy, _ := paramsMap["orderBy"].(string) - return source.SearchAspectTypes(ctx, query, pageSize, orderBy) + query, ok := paramsMap["query"].(string) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("error casting 'query' parameter: %v", paramsMap["query"]), nil) + } + pageSize, ok := paramsMap["pageSize"].(int) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("error casting 'pageSize' parameter: %v", paramsMap["pageSize"]), nil) + } + orderBy, ok := paramsMap["orderBy"].(string) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("error casting 'orderBy' parameter: %v", paramsMap["orderBy"]), nil) + } + resp, err := source.SearchAspectTypes(ctx, query, pageSize, orderBy) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go b/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go index b3dafbff98..230ef8356b 100644 --- a/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go +++ b/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go @@ -17,12 +17,14 @@ package dataplexsearchentries import ( "context" "fmt" + "net/http" "cloud.google.com/go/dataplex/apiv1/dataplexpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -93,16 +95,29 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - query, _ := paramsMap["query"].(string) - pageSize, _ := paramsMap["pageSize"].(int) - orderBy, _ := paramsMap["orderBy"].(string) - return source.SearchEntries(ctx, query, pageSize, orderBy) + query, ok := paramsMap["query"].(string) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("error casting 'query' parameter: %v", paramsMap["query"]), nil) + } + pageSize, ok := paramsMap["pageSize"].(int) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("error casting 'pageSize' parameter: %v", paramsMap["pageSize"]), nil) + } + orderBy, ok := paramsMap["orderBy"].(string) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("error casting 'orderBy' parameter: %v", paramsMap["orderBy"]), nil) + } + resp, err := source.SearchEntries(ctx, query, pageSize, orderBy) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/dgraph/dgraph.go b/internal/tools/dgraph/dgraph.go index d5e4cb72bf..fb4d76f1e1 100644 --- a/internal/tools/dgraph/dgraph.go +++ b/internal/tools/dgraph/dgraph.go @@ -17,12 +17,14 @@ package dgraph import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources/dgraph" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -91,12 +93,16 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } - return source.RunSQL(t.Statement, params, t.IsQuery, t.Timeout) + resp, err := source.RunSQL(t.Statement, params, t.IsQuery, t.Timeout) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go index 12387da8b4..3cf828d30e 100644 --- a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go +++ b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go @@ -17,9 +17,11 @@ package elasticsearchesql import ( "context" "fmt" + "net/http" "time" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/goccy/go-yaml" @@ -89,10 +91,10 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } var cancel context.CancelFunc @@ -119,11 +121,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para for _, param := range t.Parameters { if param.GetType() == "array" { - return nil, fmt.Errorf("array parameters are not supported yet") + return nil, util.NewAgentError("array parameters are not supported yet", nil) } sqlParams = append(sqlParams, map[string]any{param.GetName(): paramMap[param.GetName()]}) } - return source.RunSQL(ctx, t.Format, query, sqlParams) + resp, err := source.RunSQL(ctx, t.Format, query, sqlParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go b/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go index 40e9195ee7..b1d97f1235 100644 --- a/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go +++ b/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -90,25 +91,30 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - sql, ok := paramsMap["sql"].(string) + sqlStr, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast parameter 'sql' to string: %v", paramsMap["sql"]), nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, sql, nil) + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sqlStr)) + + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/firebird/firebirdsql/firebirdsql.go b/internal/tools/firebird/firebirdsql/firebirdsql.go index 73c455ccb6..fbadc1c2a1 100644 --- a/internal/tools/firebird/firebirdsql/firebirdsql.go +++ b/internal/tools/firebird/firebirdsql/firebirdsql.go @@ -18,12 +18,14 @@ import ( "context" "database/sql" "fmt" + "net/http" "strings" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -98,21 +100,21 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() statement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params: %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params: %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } namedArgs := make([]any, 0, len(newParams)) @@ -127,7 +129,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para namedArgs = append(namedArgs, value) } } - return source.RunSQL(ctx, statement, namedArgs) + + resp, err := source.RunSQL(ctx, statement, namedArgs) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/firestore/firestoreadddocuments/firestoreadddocuments.go b/internal/tools/firestore/firestoreadddocuments/firestoreadddocuments.go index 20c6163335..893948983d 100644 --- a/internal/tools/firestore/firestoreadddocuments/firestoreadddocuments.go +++ b/internal/tools/firestore/firestoreadddocuments/firestoreadddocuments.go @@ -17,13 +17,15 @@ package firestoreadddocuments import ( "context" "fmt" + "net/http" firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + fsUtil "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -128,32 +130,32 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() // Get collection path collectionPath, ok := mapParams[collectionPathKey].(string) if !ok || collectionPath == "" { - return nil, fmt.Errorf("invalid or missing '%s' parameter", collectionPathKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter", collectionPathKey), nil) } // Validate collection path - if err := util.ValidateCollectionPath(collectionPath); err != nil { - return nil, fmt.Errorf("invalid collection path: %w", err) + if err := fsUtil.ValidateCollectionPath(collectionPath); err != nil { + return nil, util.NewAgentError(fmt.Sprintf("invalid collection path: %v", err), err) } // Get document data documentDataRaw, ok := mapParams[documentDataKey] if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter", documentDataKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter", documentDataKey), nil) } // Convert the document data from JSON format to Firestore format // The client is passed to handle referenceValue types - documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient()) + documentData, err := fsUtil.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient()) if err != nil { - return nil, fmt.Errorf("failed to convert document data: %w", err) + return nil, util.NewAgentError(fmt.Sprintf("failed to convert document data: %v", err), err) } // Get return document data flag @@ -161,7 +163,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if val, ok := mapParams[returnDocumentDataKey].(bool); ok { returnData = val } - return source.AddDocuments(ctx, collectionPath, documentData, returnData) + resp, err := source.AddDocuments(ctx, collectionPath, documentData, returnData) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/firestore/firestoredeletedocuments/firestoredeletedocuments.go b/internal/tools/firestore/firestoredeletedocuments/firestoredeletedocuments.go index 1610c6a038..22bdf47e5a 100644 --- a/internal/tools/firestore/firestoredeletedocuments/firestoredeletedocuments.go +++ b/internal/tools/firestore/firestoredeletedocuments/firestoredeletedocuments.go @@ -17,13 +17,15 @@ package firestoredeletedocuments import ( "context" "fmt" + "net/http" firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + fsUtil "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -94,39 +96,43 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() documentPathsRaw, ok := mapParams[documentPathsKey].([]any) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected an array", documentPathsKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected an array", documentPathsKey), nil) } if len(documentPathsRaw) == 0 { - return nil, fmt.Errorf("'%s' parameter cannot be empty", documentPathsKey) + return nil, util.NewAgentError(fmt.Sprintf("'%s' parameter cannot be empty", documentPathsKey), nil) } // Use ConvertAnySliceToTyped to convert the slice typedSlice, err := parameters.ConvertAnySliceToTyped(documentPathsRaw, "string") if err != nil { - return nil, fmt.Errorf("failed to convert document paths: %w", err) + return nil, util.NewAgentError(fmt.Sprintf("failed to convert document paths: %v", err), err) } documentPaths, ok := typedSlice.([]string) if !ok { - return nil, fmt.Errorf("unexpected type conversion error for document paths") + return nil, util.NewAgentError("unexpected type conversion error for document paths", nil) } // Validate each document path for i, path := range documentPaths { - if err := util.ValidateDocumentPath(path); err != nil { - return nil, fmt.Errorf("invalid document path at index %d: %w", i, err) + if err := fsUtil.ValidateDocumentPath(path); err != nil { + return nil, util.NewAgentError(fmt.Sprintf("invalid document path at index %d: %v", i, err), err) } } - return source.DeleteDocuments(ctx, documentPaths) + resp, err := source.DeleteDocuments(ctx, documentPaths) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/firestore/firestoregetdocuments/firestoregetdocuments.go b/internal/tools/firestore/firestoregetdocuments/firestoregetdocuments.go index 5ccc68ef9b..71c4e181a6 100644 --- a/internal/tools/firestore/firestoregetdocuments/firestoregetdocuments.go +++ b/internal/tools/firestore/firestoregetdocuments/firestoregetdocuments.go @@ -17,13 +17,15 @@ package firestoregetdocuments import ( "context" "fmt" + "net/http" firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + fsUtil "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -94,40 +96,44 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() documentPathsRaw, ok := mapParams[documentPathsKey].([]any) if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter; expected an array", documentPathsKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected an array", documentPathsKey), nil) } if len(documentPathsRaw) == 0 { - return nil, fmt.Errorf("'%s' parameter cannot be empty", documentPathsKey) + return nil, util.NewAgentError(fmt.Sprintf("'%s' parameter cannot be empty", documentPathsKey), nil) } // Use ConvertAnySliceToTyped to convert the slice typedSlice, err := parameters.ConvertAnySliceToTyped(documentPathsRaw, "string") if err != nil { - return nil, fmt.Errorf("failed to convert document paths: %w", err) + return nil, util.NewAgentError(fmt.Sprintf("failed to convert document paths: %v", err), err) } documentPaths, ok := typedSlice.([]string) if !ok { - return nil, fmt.Errorf("unexpected type conversion error for document paths") + return nil, util.NewAgentError("unexpected type conversion error for document paths", nil) } // Validate each document path for i, path := range documentPaths { - if err := util.ValidateDocumentPath(path); err != nil { - return nil, fmt.Errorf("invalid document path at index %d: %w", i, err) + if err := fsUtil.ValidateDocumentPath(path); err != nil { + return nil, util.NewAgentError(fmt.Sprintf("invalid document path at index %d: %v", i, err), err) } } - return source.GetDocuments(ctx, documentPaths) + resp, err := source.GetDocuments(ctx, documentPaths) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/firestore/firestoregetrules/firestoregetrules.go b/internal/tools/firestore/firestoregetrules/firestoregetrules.go index 13453c4e30..8740a93888 100644 --- a/internal/tools/firestore/firestoregetrules/firestoregetrules.go +++ b/internal/tools/firestore/firestoregetrules/firestoregetrules.go @@ -17,11 +17,13 @@ package firestoregetrules import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/firebaserules/v1" ) @@ -92,12 +94,16 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } - return source.GetRules(ctx) + resp, err := source.GetRules(ctx) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/firestore/firestorelistcollections/firestorelistcollections.go b/internal/tools/firestore/firestorelistcollections/firestorelistcollections.go index c4bcc451e0..62db352013 100644 --- a/internal/tools/firestore/firestorelistcollections/firestorelistcollections.go +++ b/internal/tools/firestore/firestorelistcollections/firestorelistcollections.go @@ -17,13 +17,15 @@ package firestorelistcollections import ( "context" "fmt" + "net/http" firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + fsUtil "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -95,10 +97,10 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() @@ -107,11 +109,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para parentPath, _ := mapParams[parentPathKey].(string) if parentPath != "" { // Validate parent document path - if err := util.ValidateDocumentPath(parentPath); err != nil { - return nil, fmt.Errorf("invalid parent document path: %w", err) + if err := fsUtil.ValidateDocumentPath(parentPath); err != nil { + return nil, util.NewAgentError(fmt.Sprintf("invalid parent document path: %v", err), err) } } - return source.ListCollections(ctx, parentPath) + resp, err := source.ListCollections(ctx, parentPath) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/firestore/firestorequery/firestorequery.go b/internal/tools/firestore/firestorequery/firestorequery.go index 21ecd1294e..15d6e1b842 100644 --- a/internal/tools/firestore/firestorequery/firestorequery.go +++ b/internal/tools/firestore/firestorequery/firestorequery.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "strconv" "strings" @@ -26,7 +27,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + fsUtil "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -158,16 +160,16 @@ var validOperators = map[string]bool{ } // Invoke executes the Firestore query based on the provided parameters -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() // Process collection path with template substitution collectionPath, err := parameters.PopulateTemplate("collectionPath", t.CollectionPath, paramsMap) if err != nil { - return nil, fmt.Errorf("failed to process collection path: %w", err) + return nil, util.NewAgentError(fmt.Sprintf("failed to process collection path: %v", err), err) } var filter firestoreapi.EntityFilter @@ -176,13 +178,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Apply template substitution to filters filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, paramsMap) if err != nil { - return nil, fmt.Errorf("failed to process filters template: %w", err) + return nil, util.NewAgentError(fmt.Sprintf("failed to process filters template: %v", err), err) } // Parse the simplified filter format var simplifiedFilter SimplifiedFilter if err := json.Unmarshal([]byte(filtersJSON), &simplifiedFilter); err != nil { - return nil, fmt.Errorf("failed to parse filters: %w", err) + return nil, util.NewAgentError(fmt.Sprintf("failed to parse filters: %v", err), err) } // Convert simplified filter to Firestore filter @@ -191,17 +193,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Process and apply ordering orderBy, err := t.getOrderBy(paramsMap) if err != nil { - return nil, err + return nil, util.NewAgentError(fmt.Sprintf("failed to process order by: %v", err), err) } // Process select fields selectFields, err := t.processSelectFields(paramsMap) if err != nil { - return nil, err + return nil, util.NewAgentError(fmt.Sprintf("failed to process select fields: %v", err), err) } // Process and apply limit limit, err := t.getLimit(paramsMap) if err != nil { - return nil, err + return nil, util.NewAgentError(fmt.Sprintf("failed to process limit: %v", err), err) } // prevent panic when accessing orderBy incase it is nil @@ -215,10 +217,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Build the query query, err := source.BuildQuery(collectionPath, filter, selectFields, orderByField, orderByDirection, limit, t.AnalyzeQuery) if err != nil { - return nil, err + return nil, util.ProcessGcpError(err) } // Execute the query and return results - return source.ExecuteQuery(ctx, query, t.AnalyzeQuery) + resp, err := source.ExecuteQuery(ctx, query, t.AnalyzeQuery) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } // convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter @@ -255,7 +261,7 @@ func (t Tool) convertToFirestoreFilter(source compatibleSource, filter Simplifie if filter.Field != "" && filter.Op != "" && filter.Value != nil { if validOperators[filter.Op] { // Convert the value using the Firestore native JSON converter - convertedValue, err := util.JSONToFirestoreValue(filter.Value, source.FirestoreClient()) + convertedValue, err := fsUtil.JSONToFirestoreValue(filter.Value, source.FirestoreClient()) if err != nil { // If conversion fails, use the original value convertedValue = filter.Value @@ -367,7 +373,7 @@ func (t Tool) getLimit(params map[string]any) (int, error) { if processedValue != "" { parsedLimit, err := strconv.Atoi(processedValue) if err != nil { - return 0, fmt.Errorf("failed to parse limit value '%s': %w", processedValue, err) + return 0, err } limit = parsedLimit } diff --git a/internal/tools/firestore/firestorequerycollection/firestorequerycollection.go b/internal/tools/firestore/firestorequerycollection/firestorequerycollection.go index 9f8eb29007..65c44e8e0c 100644 --- a/internal/tools/firestore/firestorequerycollection/firestorequerycollection.go +++ b/internal/tools/firestore/firestorequerycollection/firestorequerycollection.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "strings" firestoreapi "cloud.google.com/go/firestore" @@ -25,7 +26,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + fsUtil "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -230,16 +232,16 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction { } // Invoke executes the Firestore query based on the provided parameters -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } // Parse parameters queryParams, err := t.parseQueryParameters(params) if err != nil { - return nil, err + return nil, util.NewAgentError(fmt.Sprintf("failed to parse query parameters: %v", err), err) } var filter firestoreapi.EntityFilter @@ -270,9 +272,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Build the query query, err := source.BuildQuery(queryParams.CollectionPath, filter, nil, orderByField, orderByDirection, queryParams.Limit, queryParams.AnalyzeQuery) if err != nil { - return nil, err + return nil, util.ProcessGcpError(err) } - return source.ExecuteQuery(ctx, query, queryParams.AnalyzeQuery) + resp, err := source.ExecuteQuery(ctx, query, queryParams.AnalyzeQuery) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } // queryParameters holds all parsed query parameters @@ -295,7 +301,7 @@ func (t Tool) parseQueryParameters(params parameters.ParamValues) (*queryParamet } // Validate collection path - if err := util.ValidateCollectionPath(collectionPath); err != nil { + if err := fsUtil.ValidateCollectionPath(collectionPath); err != nil { return nil, fmt.Errorf("invalid collection path: %w", err) } diff --git a/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go b/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go index b059d28e91..85588e6217 100644 --- a/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go +++ b/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go @@ -17,6 +17,7 @@ package firestoreupdatedocument import ( "context" "fmt" + "net/http" "strings" firestoreapi "cloud.google.com/go/firestore" @@ -24,7 +25,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + fsUtil "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -138,10 +140,10 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() @@ -149,18 +151,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Get document path documentPath, ok := mapParams[documentPathKey].(string) if !ok || documentPath == "" { - return nil, fmt.Errorf("invalid or missing '%s' parameter", documentPathKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter", documentPathKey), nil) } // Validate document path - if err := util.ValidateDocumentPath(documentPath); err != nil { - return nil, fmt.Errorf("invalid document path: %w", err) + if err := fsUtil.ValidateDocumentPath(documentPath); err != nil { + return nil, util.NewAgentError(fmt.Sprintf("invalid document path: %v", err), err) } // Get document data documentDataRaw, ok := mapParams[documentDataKey] if !ok { - return nil, fmt.Errorf("invalid or missing '%s' parameter", documentDataKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter", documentDataKey), nil) } // Get update mask if provided @@ -170,11 +172,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Use ConvertAnySliceToTyped to convert the slice typedSlice, err := parameters.ConvertAnySliceToTyped(updateMaskArray, "string") if err != nil { - return nil, fmt.Errorf("failed to convert update mask: %w", err) + return nil, util.NewAgentError(fmt.Sprintf("failed to convert update mask: %v", err), err) } updatePaths, ok = typedSlice.([]string) if !ok { - return nil, fmt.Errorf("unexpected type conversion error for update mask") + return nil, util.NewAgentError("unexpected type conversion error for update mask", nil) } } } @@ -184,15 +186,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if len(updatePaths) > 0 { // Convert document data without delete markers - dataMap, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient()) + dataMap, err := fsUtil.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient()) if err != nil { - return nil, fmt.Errorf("failed to convert document data: %w", err) + return nil, util.NewAgentError(fmt.Sprintf("failed to convert document data: %v", err), err) } // Ensure it's a map dataMapTyped, ok := dataMap.(map[string]interface{}) if !ok { - return nil, fmt.Errorf("document data must be a map") + return nil, util.NewAgentError("document data must be a map", nil) } for _, path := range updatePaths { @@ -210,9 +212,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } else { // Update all fields in the document data (merge) - documentData, err = util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient()) + documentData, err = fsUtil.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient()) if err != nil { - return nil, fmt.Errorf("failed to convert document data: %w", err) + return nil, util.NewAgentError(fmt.Sprintf("failed to convert document data: %v", err), err) } } @@ -221,7 +223,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if val, ok := mapParams[returnDocumentDataKey].(bool); ok { returnData = val } - return source.UpdateDocument(ctx, documentPath, updates, documentData, returnData) + resp, err := source.UpdateDocument(ctx, documentPath, updates, documentData, returnData) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } // getFieldValue retrieves a value from a nested map using a dot-separated path diff --git a/internal/tools/firestore/firestorevalidaterules/firestorevalidaterules.go b/internal/tools/firestore/firestorevalidaterules/firestorevalidaterules.go index 617ad80c5b..12f981b14d 100644 --- a/internal/tools/firestore/firestorevalidaterules/firestorevalidaterules.go +++ b/internal/tools/firestore/firestorevalidaterules/firestorevalidaterules.go @@ -17,11 +17,13 @@ package firestorevalidaterules import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/firebaserules/v1" ) @@ -106,10 +108,10 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() @@ -117,9 +119,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Get source parameter sourceParam, ok := mapParams[sourceKey].(string) if !ok || sourceParam == "" { - return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey) + return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter", sourceKey), nil) } - return source.ValidateRules(ctx, sourceParam) + resp, err := source.ValidateRules(ctx, sourceParam) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/http/http.go b/internal/tools/http/http.go index 0b3b49e383..c7be1c185b 100644 --- a/internal/tools/http/http.go +++ b/internal/tools/http/http.go @@ -17,18 +17,18 @@ import ( "bytes" "context" "fmt" + "maps" "net/http" "net/url" "slices" "strings" - - "maps" "text/template" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -98,7 +98,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) maps.Copy(combinedHeaders, cfg.Headers) // Create a slice for all parameters - allParameters := slices.Concat(cfg.PathParams, cfg.BodyParams, cfg.HeaderParams, cfg.QueryParams) + allParameters := slices.Concat(cfg.PathParams, cfg.QueryParams, cfg.BodyParams, cfg.HeaderParams) // Verify no duplicate parameter names err := parameters.CheckDuplicateParameters(allParameters) @@ -226,10 +226,10 @@ func getHeaders(headerParams parameters.Parameters, defaultHeaders map[string]st return allHeaders, nil } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() @@ -237,27 +237,35 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Calculate request body requestBody, err := getRequestBody(t.BodyParams, t.RequestBody, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating request body: %s", err) + return nil, util.NewAgentError("error populating request body", err) } // Calculate URL urlString, err := getURL(source.HttpBaseURL(), t.Path, t.PathParams, t.QueryParams, source.HttpQueryParams(), paramsMap) if err != nil { - return nil, fmt.Errorf("error populating path parameters: %s", err) + return nil, util.NewAgentError("error populating path parameters", err) } - req, _ := http.NewRequest(string(t.Method), urlString, strings.NewReader(requestBody)) + req, err := http.NewRequestWithContext(ctx, string(t.Method), urlString, strings.NewReader(requestBody)) + if err != nil { + return nil, util.NewClientServerError("error creating http request", http.StatusInternalServerError, err) + } // Calculate request headers allHeaders, err := getHeaders(t.HeaderParams, t.Headers, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating request headers: %s", err) + return nil, util.NewAgentError("error populating request headers", err) } // Set request headers for k, v := range allHeaders { req.Header.Set(k, v) } - return source.RunRequest(req) + + resp, err := source.RunRequest(req) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/looker/lookeradddashboardelement/lookeradddashboardelement.go b/internal/tools/looker/lookeradddashboardelement/lookeradddashboardelement.go index 5b1103c102..5c6a6e5880 100644 --- a/internal/tools/looker/lookeradddashboardelement/lookeradddashboardelement.go +++ b/internal/tools/looker/lookeradddashboardelement/lookeradddashboardelement.go @@ -16,6 +16,7 @@ package lookeradddashboardelement import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -134,58 +135,74 @@ var ( visType string = "vis" ) -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "params = ", params) wq, err := lookercommon.ProcessQueryArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error building query request: %w", err) + return nil, util.NewAgentError("error building query request", err) } paramsMap := params.AsMap() - dashboard_id := paramsMap["dashboard_id"].(string) - title := paramsMap["title"].(string) - visConfig := paramsMap["vis_config"].(map[string]any) + dashboard_id, ok := paramsMap["dashboard_id"].(string) + if !ok { + return nil, util.NewAgentError("dashboard_id parameter missing or invalid", nil) + } + + title, ok := paramsMap["title"].(string) + if !ok { + title = "" + } + + visConfig, ok := paramsMap["vis_config"].(map[string]any) + if !ok { + visConfig = make(map[string]any) + } wq.VisConfig = &visConfig sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } qresp, err := sdk.CreateQuery(*wq, "id", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making create query request: %w", err) + return nil, util.ProcessGeneralError(err) } dashFilters := []any{} if v, ok := paramsMap["dashboard_filters"]; ok { if v != nil { - dashFilters = paramsMap["dashboard_filters"].([]any) + if df, ok := v.([]any); ok { + dashFilters = df + } } } var filterables []v4.ResultMakerFilterables for _, m := range dashFilters { - f := m.(map[string]any) + f, ok := m.(map[string]any) + if !ok { + return nil, util.NewAgentError("invalid dashboard filter structure", nil) + } name, ok := f["dashboard_filter_name"].(string) if !ok { - return nil, fmt.Errorf("error processing dashboard filter: %w", err) + return nil, util.NewAgentError("error processing dashboard filter: missing dashboard_filter_name", nil) } field, ok := f["field"].(string) if !ok { - return nil, fmt.Errorf("error processing dashboard filter: %w", err) + return nil, util.NewAgentError("error processing dashboard filter: missing field", nil) } listener := v4.ResultMakerFilterablesListen{ DashboardFilterName: &name, @@ -233,7 +250,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para resp, err := sdk.CreateDashboardElement(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making create dashboard element request: %w", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = %v", resp) diff --git a/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter.go b/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter.go index e3da8838f8..71ca790850 100644 --- a/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter.go +++ b/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter.go @@ -16,6 +16,7 @@ package lookeradddashboardfilter import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -128,33 +129,54 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "params = ", params) paramsMap := params.AsMap() - dashboard_id := paramsMap["dashboard_id"].(string) - name := paramsMap["name"].(string) - title := paramsMap["title"].(string) - filterType := paramsMap["filter_type"].(string) + dashboard_id, ok := paramsMap["dashboard_id"].(string) + if !ok { + return nil, util.NewAgentError("dashboard_id parameter missing or invalid", nil) + } + name, ok := paramsMap["name"].(string) + if !ok { + return nil, util.NewAgentError("name parameter missing or invalid", nil) + } + title, ok := paramsMap["title"].(string) + if !ok { + return nil, util.NewAgentError("title parameter missing or invalid", nil) + } + filterType, ok := paramsMap["filter_type"].(string) + if !ok { + return nil, util.NewAgentError("filter_type parameter missing or invalid", nil) + } + switch filterType { case "date_filter": case "number_filter": case "string_filter": case "field_filter": default: - return nil, fmt.Errorf("invalid filter type: %s. Must be one of date_filter, number_filter, string_filter, field_filter", filterType) + return nil, util.NewAgentError(fmt.Sprintf("invalid filter type: %s. Must be one of date_filter, number_filter, string_filter, field_filter", filterType), nil) + } + + allowMultipleValues, ok := paramsMap["allow_multiple_values"].(bool) + if !ok { + // defaults should handle this, but safe fallback + allowMultipleValues = true + } + required, ok := paramsMap["required"].(bool) + if !ok { + required = false } - allowMultipleValues := paramsMap["allow_multiple_values"].(bool) - required := paramsMap["required"].(bool) req := v4.WriteCreateDashboardFilter{ DashboardId: dashboard_id, @@ -165,9 +187,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Required: &required, } - if v, ok := paramsMap["default_value"]; ok { - if v != nil { - defaultValue := paramsMap["default_value"].(string) + if v, ok := paramsMap["default_value"]; ok && v != nil { + if defaultValue, ok := v.(string); ok { req.DefaultValue = &defaultValue } } @@ -175,15 +196,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if filterType == "field_filter" { model, ok := paramsMap["model"].(string) if !ok || model == "" { - return nil, fmt.Errorf("model must be specified for field_filter type") + return nil, util.NewAgentError("model must be specified for field_filter type", nil) } explore, ok := paramsMap["explore"].(string) if !ok || explore == "" { - return nil, fmt.Errorf("explore must be specified for field_filter type") + return nil, util.NewAgentError("explore must be specified for field_filter type", nil) } dimension, ok := paramsMap["dimension"].(string) if !ok || dimension == "" { - return nil, fmt.Errorf("dimension must be specified for field_filter type") + return nil, util.NewAgentError("dimension must be specified for field_filter type", nil) } req.Model = &model @@ -193,12 +214,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } resp, err := sdk.CreateDashboardFilter(req, "name", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making create dashboard filter request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = %v", resp) diff --git a/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go b/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go index 3c548abc49..3eb28c4d5e 100644 --- a/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go +++ b/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go @@ -215,10 +215,10 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } var tokenStr string @@ -226,11 +226,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Get credentials for the API call // Use cloud-platform token source for Gemini Data Analytics API if t.TokenSource == nil { - return nil, fmt.Errorf("cloud-platform token source is missing") + return nil, util.NewClientServerError("cloud-platform token source is missing", http.StatusInternalServerError, nil) } token, err := t.TokenSource.Token() if err != nil { - return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err) + return nil, util.NewClientServerError("failed to get token from cloud-platform token source", http.StatusInternalServerError, err) } tokenStr = token.AccessToken @@ -286,7 +286,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Call the streaming API response, err := getStream(ctx, caURL, payload, headers) if err != nil { - return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err) + return nil, util.NewClientServerError("failed to get response from conversational analytics API", http.StatusInternalServerError, err) } return response, nil diff --git a/internal/tools/looker/lookercreateprojectfile/lookercreateprojectfile.go b/internal/tools/looker/lookercreateprojectfile/lookercreateprojectfile.go index bcbdc02014..830df321f7 100644 --- a/internal/tools/looker/lookercreateprojectfile/lookercreateprojectfile.go +++ b/internal/tools/looker/lookercreateprojectfile/lookercreateprojectfile.go @@ -16,12 +16,14 @@ package lookercreateprojectfile import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/looker-open-source/sdk-codegen/go/rtl" @@ -110,29 +112,29 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } mapParams := params.AsMap() projectId, ok := mapParams["project_id"].(string) if !ok { - return nil, fmt.Errorf("'project_id' must be a string, got %T", mapParams["project_id"]) + return nil, util.NewAgentError(fmt.Sprintf("'project_id' must be a string, got %T", mapParams["project_id"]), nil) } filePath, ok := mapParams["file_path"].(string) if !ok { - return nil, fmt.Errorf("'file_path' must be a string, got %T", mapParams["file_path"]) + return nil, util.NewAgentError(fmt.Sprintf("'file_path' must be a string, got %T", mapParams["file_path"]), nil) } fileContent, ok := mapParams["file_content"].(string) if !ok { - return nil, fmt.Errorf("'file_content' must be a string, got %T", mapParams["file_content"]) + return nil, util.NewAgentError(fmt.Sprintf("'file_content' must be a string, got %T", mapParams["file_content"]), nil) } req := lookercommon.FileContent{ @@ -142,7 +144,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para err = lookercommon.CreateProjectFile(sdk, projectId, req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making create_project_file request: %s", err) + return nil, util.ProcessGeneralError(err) } data := make(map[string]any) diff --git a/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go b/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go index 7644adee6a..741b6bb220 100644 --- a/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go +++ b/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go @@ -16,12 +16,14 @@ package lookerdeleteprojectfile import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/looker-open-source/sdk-codegen/go/rtl" @@ -111,30 +113,30 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } mapParams := params.AsMap() projectId, ok := mapParams["project_id"].(string) if !ok { - return nil, fmt.Errorf("'project_id' must be a string, got %T", mapParams["project_id"]) + return nil, util.NewAgentError(fmt.Sprintf("'project_id' must be a string, got %T", mapParams["project_id"]), nil) } filePath, ok := mapParams["file_path"].(string) if !ok { - return nil, fmt.Errorf("'file_path' must be a string, got %T", mapParams["file_path"]) + return nil, util.NewAgentError(fmt.Sprintf("'file_path' must be a string, got %T", mapParams["file_path"]), nil) } err = lookercommon.DeleteProjectFile(sdk, projectId, filePath, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making delete_project_file request: %s", err) + return nil, util.ProcessGeneralError(err) } data := make(map[string]any) diff --git a/internal/tools/looker/lookerdevmode/lookerdevmode.go b/internal/tools/looker/lookerdevmode/lookerdevmode.go index ea16d4a7ad..274062052f 100644 --- a/internal/tools/looker/lookerdevmode/lookerdevmode.go +++ b/internal/tools/looker/lookerdevmode/lookerdevmode.go @@ -16,6 +16,7 @@ package lookerdevmode import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -60,7 +61,6 @@ type Config struct { Annotations *tools.ToolAnnotations `yaml:"annotations,omitempty"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -81,7 +81,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) - // finish tool setup return Tool{ Config: cfg, Parameters: params, @@ -94,7 +93,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -108,25 +106,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } mapParams := params.AsMap() devMode, ok := mapParams["devMode"].(bool) if !ok { - return nil, fmt.Errorf("'devMode' must be a boolean, got %T", mapParams["devMode"]) + return nil, util.NewAgentError(fmt.Sprintf("'devMode' must be a boolean, got %T", mapParams["devMode"]), nil) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } var devModeString string if devMode { @@ -139,7 +137,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.UpdateSession(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error setting/resetting dev mode: %w", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "result = ", resp) diff --git a/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go b/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go index 9ffc6f2f8e..908aae9e18 100644 --- a/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go +++ b/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -17,6 +17,7 @@ package lookergenerateembedurl import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -114,15 +115,15 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } paramsMap := params.AsMap() embedType := paramsMap["type"].(string) @@ -138,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } forceLogoutLogin := true @@ -151,7 +152,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para logger.ErrorContext(ctx, "Making request %v", req) resp, err := sdk.CreateEmbedUrlAsMe(req, nil) if err != nil { - return nil, fmt.Errorf("error making create_embed_url_as_me request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.ErrorContext(ctx, "Got response %v", resp) diff --git a/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go b/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go index 81c156f78b..c62fcef1f1 100644 --- a/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go +++ b/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go @@ -16,11 +16,13 @@ package lookergetconnectiondatabases import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/looker-open-source/sdk-codegen/go/rtl" @@ -107,27 +109,26 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() conn, ok := mapParams["conn"].(string) if !ok { - return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"]) + return nil, util.NewAgentError(fmt.Sprintf("'conn' must be a string, got %T", mapParams["conn"]), nil) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } resp, err := sdk.ConnectionDatabases(conn, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_connection_databases request: %s", err) + return nil, util.ProcessGeneralError(err) } - //logger.DebugContext(ctx, "Got response of %v\n", resp) return resp, nil } diff --git a/internal/tools/looker/lookergetconnections/lookergetconnections.go b/internal/tools/looker/lookergetconnections/lookergetconnections.go index e223df0c04..585b30f8ca 100644 --- a/internal/tools/looker/lookergetconnections/lookergetconnections.go +++ b/internal/tools/looker/lookergetconnections/lookergetconnections.go @@ -16,6 +16,7 @@ package lookergetconnections import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -107,24 +108,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } resp, err := sdk.AllConnections("name, dialect(name), database, schema", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_connections request: %s", err) + return nil, util.ProcessGeneralError(err) } var data []any @@ -140,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } conn, err := sdk.ConnectionFeatures(*v.Name, "multiple_databases", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_connection_features request: %s", err) + return nil, util.ProcessGeneralError(err) } vMap["supports_multiple_databases"] = *conn.MultipleDatabases logger.DebugContext(ctx, "Converted to %v\n", vMap) diff --git a/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go b/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go index e7385d1c64..3528b08a2e 100644 --- a/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go +++ b/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go @@ -16,11 +16,13 @@ package lookergetconnectionschemas import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/looker-open-source/sdk-codegen/go/rtl" @@ -108,22 +110,22 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() conn, ok := mapParams["conn"].(string) if !ok { - return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"]) + return nil, util.NewAgentError(fmt.Sprintf("'conn' must be a string, got %T", mapParams["conn"]), nil) } db, _ := mapParams["db"].(string) sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } req := v4.RequestConnectionSchemas{ ConnectionName: conn, @@ -133,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.ConnectionSchemas(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_connection_schemas request: %s", err) + return nil, util.ProcessGeneralError(err) } return resp, nil } diff --git a/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go b/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go index 263034e73a..61e3974c1a 100644 --- a/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go +++ b/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go @@ -8,7 +8,7 @@ // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY, either express or implied. +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package lookergetconnectiontablecolumns @@ -16,6 +16,7 @@ package lookergetconnectiontablecolumns import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -111,34 +112,34 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } mapParams := params.AsMap() conn, ok := mapParams["conn"].(string) if !ok { - return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"]) + return nil, util.NewAgentError(fmt.Sprintf("'conn' must be a string, got %T", mapParams["conn"]), nil) } db, _ := mapParams["db"].(string) schema, ok := mapParams["schema"].(string) if !ok { - return nil, fmt.Errorf("'schema' must be a string, got %T", mapParams["schema"]) + return nil, util.NewAgentError(fmt.Sprintf("'schema' must be a string, got %T", mapParams["schema"]), nil) } tables, ok := mapParams["tables"].(string) if !ok { - return nil, fmt.Errorf("'tables' must be a string, got %T", mapParams["tables"]) + return nil, util.NewAgentError(fmt.Sprintf("'tables' must be a string, got %T", mapParams["tables"]), nil) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } req := v4.RequestConnectionColumns{ ConnectionName: conn, @@ -150,7 +151,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.ConnectionColumns(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_connection_table_columns request: %s", err) + return nil, util.ProcessGeneralError(err) } var data []any for _, t := range resp { diff --git a/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go b/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go index 6d993f96a4..0a26997efe 100644 --- a/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go +++ b/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go @@ -16,6 +16,7 @@ package lookergetconnectiontables import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -110,30 +111,30 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } mapParams := params.AsMap() conn, ok := mapParams["conn"].(string) if !ok { - return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"]) + return nil, util.NewAgentError(fmt.Sprintf("'conn' must be a string, got %T", mapParams["conn"]), nil) } db, _ := mapParams["db"].(string) schema, ok := mapParams["schema"].(string) if !ok { - return nil, fmt.Errorf("'schema' must be a string, got %T", mapParams["schema"]) + return nil, util.NewAgentError(fmt.Sprintf("'schema' must be a string, got %T", mapParams["schema"]), nil) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } req := v4.RequestConnectionTables{ ConnectionName: conn, @@ -144,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.ConnectionTables(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_connection_tables request: %s", err) + return nil, util.ProcessGeneralError(err) } var data []any for _, s := range resp { diff --git a/internal/tools/looker/lookergetdashboards/lookergetdashboards.go b/internal/tools/looker/lookergetdashboards/lookergetdashboards.go index 64ac53783e..a4df73fbf7 100644 --- a/internal/tools/looker/lookergetdashboards/lookergetdashboards.go +++ b/internal/tools/looker/lookergetdashboards/lookergetdashboards.go @@ -16,6 +16,7 @@ package lookergetdashboards import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -116,15 +117,15 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } paramsMap := params.AsMap() title := paramsMap["title"].(string) @@ -142,7 +143,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } req := v4.RequestSearchDashboards{ Title: title_ptr, @@ -153,7 +154,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para logger.ErrorContext(ctx, "Making request %v", req) resp, err := sdk.SearchDashboards(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_dashboards request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.ErrorContext(ctx, "Got response %v", resp) var data []any diff --git a/internal/tools/looker/lookergetdimensions/lookergetdimensions.go b/internal/tools/looker/lookergetdimensions/lookergetdimensions.go index 3494207128..3373fe7db4 100644 --- a/internal/tools/looker/lookergetdimensions/lookergetdimensions.go +++ b/internal/tools/looker/lookergetdimensions/lookergetdimensions.go @@ -16,6 +16,7 @@ package lookergetdimensions import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -109,24 +110,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } model, explore, err := lookercommon.ProcessFieldArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error processing model or explore: %w", err) + return nil, util.NewAgentError("error processing model or explore", err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } fields := lookercommon.DimensionsFields req := v4.RequestLookmlModelExplore{ @@ -136,16 +137,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_dimensions request: %w", err) + return nil, util.ProcessGeneralError(err) } if err := lookercommon.CheckLookerExploreFields(&resp); err != nil { - return nil, fmt.Errorf("error processing get_dimensions response: %w", err) + return nil, util.ProcessGeneralError(err) } data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Dimensions, source.LookerShowHiddenFields()) if err != nil { - return nil, fmt.Errorf("error extracting get_dimensions response: %w", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "data = ", data) diff --git a/internal/tools/looker/lookergetexplores/lookergetexplores.go b/internal/tools/looker/lookergetexplores/lookergetexplores.go index ea5c83e45f..2a58ffdcde 100644 --- a/internal/tools/looker/lookergetexplores/lookergetexplores.go +++ b/internal/tools/looker/lookergetexplores/lookergetexplores.go @@ -16,6 +16,7 @@ package lookergetexplores import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -109,29 +110,29 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } mapParams := params.AsMap() model, ok := mapParams["model"].(string) if !ok { - return nil, fmt.Errorf("'model' must be a string, got %T", mapParams["model"]) + return nil, util.NewAgentError(fmt.Sprintf("'model' must be a string, got %T", mapParams["model"]), nil) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } resp, err := sdk.LookmlModel(model, "explores(name,description,label,group_label,hidden)", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_explores request: %s", err) + return nil, util.ProcessGeneralError(err) } var data []any diff --git a/internal/tools/looker/lookergetfilters/lookergetfilters.go b/internal/tools/looker/lookergetfilters/lookergetfilters.go index 49e86d338d..20db21c0b5 100644 --- a/internal/tools/looker/lookergetfilters/lookergetfilters.go +++ b/internal/tools/looker/lookergetfilters/lookergetfilters.go @@ -16,6 +16,7 @@ package lookergetfilters import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -109,25 +110,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } model, explore, err := lookercommon.ProcessFieldArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error processing model or explore: %w", err) + return nil, util.NewAgentError(fmt.Sprintf("error processing model or explore: %v", err), err) } fields := lookercommon.FiltersFields sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error getting sdk: %v", err), http.StatusInternalServerError, err) } req := v4.RequestLookmlModelExplore{ LookmlModelName: *model, @@ -136,16 +137,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_filters request: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error making get_filters request: %v", err), http.StatusInternalServerError, err) } if err := lookercommon.CheckLookerExploreFields(&resp); err != nil { - return nil, fmt.Errorf("error processing get_filters response: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error processing get_filters response: %v", err), http.StatusInternalServerError, err) } data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Filters, source.LookerShowHiddenFields()) if err != nil { - return nil, fmt.Errorf("error extracting get_filters response: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error extracting get_filters response: %v", err), http.StatusInternalServerError, err) } logger.DebugContext(ctx, "data = ", data) diff --git a/internal/tools/looker/lookergetlooks/lookergetlooks.go b/internal/tools/looker/lookergetlooks/lookergetlooks.go index 877a5c8586..f6642c5772 100644 --- a/internal/tools/looker/lookergetlooks/lookergetlooks.go +++ b/internal/tools/looker/lookergetlooks/lookergetlooks.go @@ -16,6 +16,7 @@ package lookergetlooks import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -116,23 +117,29 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - title := paramsMap["title"].(string) + title, ok := paramsMap["title"].(string) + if !ok { + return nil, util.NewAgentError("missing or invalid 'title' parameter", nil) + } title_ptr := &title if *title_ptr == "" { title_ptr = nil } - desc := paramsMap["desc"].(string) + desc, ok := paramsMap["desc"].(string) + if !ok { + return nil, util.NewAgentError("missing or invalid 'desc' parameter", nil) + } desc_ptr := &desc if *desc_ptr == "" { desc_ptr = nil @@ -142,7 +149,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error getting sdk: %v", err), http.StatusInternalServerError, err) } req := v4.RequestSearchLooks{ Title: title_ptr, @@ -152,7 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.SearchLooks(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_looks request: %s", err) + return nil, util.NewClientServerError(fmt.Sprintf("error making get_looks request: %s", err), http.StatusInternalServerError, err) } var data []any diff --git a/internal/tools/looker/lookergetmeasures/lookergetmeasures.go b/internal/tools/looker/lookergetmeasures/lookergetmeasures.go index 5d5ed52e75..d326c55909 100644 --- a/internal/tools/looker/lookergetmeasures/lookergetmeasures.go +++ b/internal/tools/looker/lookergetmeasures/lookergetmeasures.go @@ -16,6 +16,7 @@ package lookergetmeasures import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -109,25 +110,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } model, explore, err := lookercommon.ProcessFieldArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error processing model or explore: %w", err) + return nil, util.NewAgentError(fmt.Sprintf("error processing model or explore: %v", err), err) } fields := lookercommon.MeasuresFields sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error getting sdk: %v", err), http.StatusInternalServerError, err) } req := v4.RequestLookmlModelExplore{ LookmlModelName: *model, @@ -136,16 +137,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_measures request: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error making get_measures request: %v", err), http.StatusInternalServerError, err) } if err := lookercommon.CheckLookerExploreFields(&resp); err != nil { - return nil, fmt.Errorf("error processing get_measures response: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error processing get_measures response: %v", err), http.StatusInternalServerError, err) } data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Measures, source.LookerShowHiddenFields()) if err != nil { - return nil, fmt.Errorf("error extracting get_measures response: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error extracting get_measures response: %v", err), http.StatusInternalServerError, err) } logger.DebugContext(ctx, "data = ", data) diff --git a/internal/tools/looker/lookergetmodels/lookergetmodels.go b/internal/tools/looker/lookergetmodels/lookergetmodels.go index 2caf1d1efc..221570f652 100644 --- a/internal/tools/looker/lookergetmodels/lookergetmodels.go +++ b/internal/tools/looker/lookergetmodels/lookergetmodels.go @@ -16,6 +16,7 @@ package lookergetmodels import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -108,15 +109,15 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } excludeEmpty := false @@ -125,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error getting sdk: %v", err), http.StatusInternalServerError, err) } req := v4.RequestAllLookmlModels{ ExcludeEmpty: &excludeEmpty, @@ -134,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.AllLookmlModels(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_models request: %s", err) + return nil, util.NewClientServerError(fmt.Sprintf("error making get_models request: %s", err), http.StatusInternalServerError, err) } var data []any diff --git a/internal/tools/looker/lookergetparameters/lookergetparameters.go b/internal/tools/looker/lookergetparameters/lookergetparameters.go index 13d6e9b8d0..172c6d0cdf 100644 --- a/internal/tools/looker/lookergetparameters/lookergetparameters.go +++ b/internal/tools/looker/lookergetparameters/lookergetparameters.go @@ -7,7 +7,7 @@ // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, +// distributed under the License is distributed under an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. @@ -16,6 +16,7 @@ package lookergetparameters import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -109,25 +110,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } model, explore, err := lookercommon.ProcessFieldArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error processing model or explore: %w", err) + return nil, util.NewAgentError(fmt.Sprintf("error processing model or explore: %v", err), err) } fields := lookercommon.ParametersFields sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error getting sdk: %v", err), http.StatusInternalServerError, err) } req := v4.RequestLookmlModelExplore{ LookmlModelName: *model, @@ -136,16 +137,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_parameters request: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error making get_parameters request: %v", err), http.StatusInternalServerError, err) } if err := lookercommon.CheckLookerExploreFields(&resp); err != nil { - return nil, fmt.Errorf("error processing get_parameters response: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error processing get_parameters response: %v", err), http.StatusInternalServerError, err) } data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Parameters, source.LookerShowHiddenFields()) if err != nil { - return nil, fmt.Errorf("error extracting get_parameters response: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error extracting get_parameters response: %v", err), http.StatusInternalServerError, err) } logger.DebugContext(ctx, "data = ", data) diff --git a/internal/tools/looker/lookergetprojectfile/lookergetprojectfile.go b/internal/tools/looker/lookergetprojectfile/lookergetprojectfile.go index bc2ced3e2b..378111b754 100644 --- a/internal/tools/looker/lookergetprojectfile/lookergetprojectfile.go +++ b/internal/tools/looker/lookergetprojectfile/lookergetprojectfile.go @@ -16,6 +16,7 @@ package lookergetprojectfile import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -110,35 +111,35 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error getting sdk: %v", err), http.StatusInternalServerError, err) } mapParams := params.AsMap() projectId, ok := mapParams["project_id"].(string) if !ok { - return nil, fmt.Errorf("'project_id' must be a string, got %T", mapParams["project_id"]) + return nil, util.NewAgentError(fmt.Sprintf("'project_id' must be a string, got %T", mapParams["project_id"]), nil) } filePath, ok := mapParams["file_path"].(string) if !ok { - return nil, fmt.Errorf("'file_path' must be a string, got %T", mapParams["file_path"]) + return nil, util.NewAgentError(fmt.Sprintf("'file_path' must be a string, got %T", mapParams["file_path"]), nil) } resp, err := lookercommon.GetProjectFileContent(sdk, projectId, filePath, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_project_file request: %s", err) + return nil, util.NewClientServerError(fmt.Sprintf("error making get_project_file request: %s", err), http.StatusInternalServerError, err) } logger.DebugContext(ctx, "Got response of %v\n", resp) diff --git a/internal/tools/looker/lookergetprojectfiles/lookergetprojectfiles.go b/internal/tools/looker/lookergetprojectfiles/lookergetprojectfiles.go index 9ba42e5916..b2a05ff626 100644 --- a/internal/tools/looker/lookergetprojectfiles/lookergetprojectfiles.go +++ b/internal/tools/looker/lookergetprojectfiles/lookergetprojectfiles.go @@ -16,6 +16,7 @@ package lookergetprojectfiles import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -108,31 +109,31 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error getting sdk: %v", err), http.StatusInternalServerError, err) } mapParams := params.AsMap() projectId, ok := mapParams["project_id"].(string) if !ok { - return nil, fmt.Errorf("'project_id' must be a string, got %T", mapParams["project_id"]) + return nil, util.NewAgentError(fmt.Sprintf("'project_id' must be a string, got %T", mapParams["project_id"]), nil) } resp, err := sdk.AllProjectFiles(projectId, "", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_project_files request: %s", err) + return nil, util.NewClientServerError(fmt.Sprintf("error making get_project_files request: %s", err), http.StatusInternalServerError, err) } var data []any diff --git a/internal/tools/looker/lookergetprojects/lookergetprojects.go b/internal/tools/looker/lookergetprojects/lookergetprojects.go index ae93d87790..74118951a6 100644 --- a/internal/tools/looker/lookergetprojects/lookergetprojects.go +++ b/internal/tools/looker/lookergetprojects/lookergetprojects.go @@ -16,6 +16,7 @@ package lookergetprojects import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -107,25 +108,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error getting sdk: %v", err), http.StatusInternalServerError, err) } resp, err := sdk.AllProjects("id,name", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_models request: %s", err) + return nil, util.NewClientServerError(fmt.Sprintf("error making get_models request: %s", err), http.StatusInternalServerError, err) } var data []any diff --git a/internal/tools/looker/lookerhealthanalyze/lookerhealthanalyze.go b/internal/tools/looker/lookerhealthanalyze/lookerhealthanalyze.go index 140842f0b3..ae28c7e5f9 100644 --- a/internal/tools/looker/lookerhealthanalyze/lookerhealthanalyze.go +++ b/internal/tools/looker/lookerhealthanalyze/lookerhealthanalyze.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "regexp" "strings" @@ -125,20 +126,20 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error getting sdk: %v", err), http.StatusInternalServerError, err) } paramsMap := params.AsMap() @@ -159,7 +160,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para action, ok := paramsMap["action"].(string) if !ok { - return nil, fmt.Errorf("action parameter not found") + return nil, util.NewAgentError("action parameter not found", nil) } switch action { @@ -167,7 +168,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para projectId, _ := paramsMap["project"].(string) result, err := analyzeTool.projects(ctx, projectId) if err != nil { - return nil, fmt.Errorf("error analyzing projects: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error analyzing projects: %v", err), http.StatusInternalServerError, err) } logger.DebugContext(ctx, "result = ", result) return result, nil @@ -176,7 +177,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para modelName, _ := paramsMap["model"].(string) result, err := analyzeTool.models(ctx, projectName, modelName) if err != nil { - return nil, fmt.Errorf("error analyzing models: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error analyzing models: %v", err), http.StatusInternalServerError, err) } logger.DebugContext(ctx, "result = ", result) return result, nil @@ -185,12 +186,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para exploreName, _ := paramsMap["explore"].(string) result, err := analyzeTool.explores(ctx, modelName, exploreName) if err != nil { - return nil, fmt.Errorf("error analyzing explores: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error analyzing explores: %v", err), http.StatusInternalServerError, err) } logger.DebugContext(ctx, "result = ", result) return result, nil default: - return nil, fmt.Errorf("unknown action: %s", action) + return nil, util.NewAgentError(fmt.Sprintf("unknown action: %s", action), nil) } } @@ -231,23 +232,23 @@ type analyzeTool struct { minQueries int } -func (t *analyzeTool) projects(ctx context.Context, id string) ([]map[string]interface{}, error) { +func (t *analyzeTool) projects(ctx context.Context, id string) ([]map[string]interface{}, util.ToolboxError) { logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } var projects []*v4.Project if id != "" { p, err := t.SdkClient.Project(id, "", nil) if err != nil { - return nil, fmt.Errorf("error fetching project %s: %w", id, err) + return nil, util.NewClientServerError(fmt.Sprintf("error fetching project %s: %v", id, err), http.StatusInternalServerError, err) } projects = append(projects, &p) } else { allProjects, err := t.SdkClient.AllProjects("", nil) if err != nil { - return nil, fmt.Errorf("error fetching all projects: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error fetching all projects: %v", err), http.StatusInternalServerError, err) } for i := range allProjects { projects = append(projects, &allProjects[i]) @@ -262,7 +263,7 @@ func (t *analyzeTool) projects(ctx context.Context, id string) ([]map[string]int projectFiles, err := t.SdkClient.AllProjectFiles(pID, "", nil) if err != nil { - return nil, fmt.Errorf("error fetching files for project %s: %w", pName, err) + return nil, util.NewClientServerError(fmt.Sprintf("error fetching files for project %s: %v", pName, err), http.StatusInternalServerError, err) } modelCount := 0 @@ -297,21 +298,21 @@ func (t *analyzeTool) projects(ctx context.Context, id string) ([]map[string]int return results, nil } -func (t *analyzeTool) models(ctx context.Context, project, model string) ([]map[string]interface{}, error) { +func (t *analyzeTool) models(ctx context.Context, project, model string) ([]map[string]interface{}, util.ToolboxError) { logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.InfoContext(ctx, "Analyzing models...") usedModels, err := t.getUsedModels(ctx) if err != nil { - return nil, err + return nil, util.NewClientServerError("error fetching used models", http.StatusInternalServerError, err) } lookmlModels, err := t.SdkClient.AllLookmlModels(v4.RequestAllLookmlModels{}, nil) if err != nil { - return nil, fmt.Errorf("error fetching LookML models: %w", err) + return nil, util.NewClientServerError("error fetching LookML models", http.StatusInternalServerError, err) } var results []map[string]interface{} @@ -356,7 +357,7 @@ func (t *analyzeTool) getUsedModels(ctx context.Context) (map[string]int, error) } raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", nil) if err != nil { - return nil, err + return nil, util.NewClientServerError(fmt.Sprintf("error running inline query for used models: %v", err), http.StatusInternalServerError, err) } var data []map[string]interface{} @@ -371,7 +372,7 @@ func (t *analyzeTool) getUsedModels(ctx context.Context) (map[string]int, error) return results, nil } -func (t *analyzeTool) getUsedExploreFields(ctx context.Context, model, explore string) (map[string]int, error) { +func (t *analyzeTool) getUsedExploreFields(ctx context.Context, model, explore string) (map[string]int, util.ToolboxError) { limit := "5000" query := &v4.WriteQuery{ Model: "system__activity", @@ -388,7 +389,7 @@ func (t *analyzeTool) getUsedExploreFields(ctx context.Context, model, explore s } raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", nil) if err != nil { - return nil, err + return nil, util.NewClientServerError(fmt.Sprintf("error running inline query for used explore fields: %v", err), http.StatusInternalServerError, err) } var data []map[string]interface{} @@ -418,16 +419,16 @@ func (t *analyzeTool) getUsedExploreFields(ctx context.Context, model, explore s return results, nil } -func (t *analyzeTool) explores(ctx context.Context, model, explore string) ([]map[string]interface{}, error) { +func (t *analyzeTool) explores(ctx context.Context, model, explore string) ([]map[string]interface{}, util.ToolboxError) { logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.InfoContext(ctx, "Analyzing explores...") lookmlModels, err := t.SdkClient.AllLookmlModels(v4.RequestAllLookmlModels{}, nil) if err != nil { - return nil, fmt.Errorf("error fetching LookML models: %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("error fetching LookML models: %v", err), http.StatusInternalServerError, err) } var results []map[string]interface{} @@ -534,7 +535,7 @@ func (t *analyzeTool) explores(ctx context.Context, model, explore string) ([]ma rawQueryCount, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, queryCountQueryBody, "json", nil) if err != nil { - return nil, err + return nil, util.NewClientServerError(fmt.Sprintf("error running inline query for query count: %v", err), http.StatusInternalServerError, err) } queryCount := 0 var data []map[string]interface{} diff --git a/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go b/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go index fd5c3ead21..6d0177ecf5 100644 --- a/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go +++ b/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "strings" yaml "github.com/goccy/go-yaml" @@ -116,20 +117,20 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } pulseTool := &pulseTool{ @@ -140,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para paramsMap := params.AsMap() action, ok := paramsMap["action"].(string) if !ok { - return nil, fmt.Errorf("action parameter not found") + return nil, util.NewAgentError("action parameter not found", nil) } pulseParams := PulseParams{ @@ -149,7 +150,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para result, err := pulseTool.RunPulse(ctx, source, pulseParams) if err != nil { - return nil, fmt.Errorf("error running pulse: %w", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "result = ", result) diff --git a/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go b/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go index e0a0580b1d..c3b5658d57 100644 --- a/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go +++ b/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "regexp" "strings" @@ -125,15 +126,15 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } paramsMap := params.AsMap() @@ -154,21 +155,29 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para action, ok := paramsMap["action"].(string) if !ok { - return nil, fmt.Errorf("action parameter not found") + return nil, util.NewAgentError("action parameter not found", nil) } + var res []map[string]interface{} + switch action { case "models": project, _ := paramsMap["project"].(string) model, _ := paramsMap["model"].(string) - return vacuumTool.models(ctx, project, model) + res, err = vacuumTool.models(ctx, project, model) case "explores": model, _ := paramsMap["model"].(string) explore, _ := paramsMap["explore"].(string) - return vacuumTool.explores(ctx, model, explore) + res, err = vacuumTool.explores(ctx, model, explore) default: - return nil, fmt.Errorf("unknown action: %s", action) + return nil, util.NewAgentError(fmt.Sprintf("unknown action: %s", action), nil) } + + if err != nil { + return nil, util.ProcessGeneralError(err) + } + + return res, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/looker/lookermakedashboard/lookermakedashboard.go b/internal/tools/looker/lookermakedashboard/lookermakedashboard.go index e4da154180..8934e46dc9 100644 --- a/internal/tools/looker/lookermakedashboard/lookermakedashboard.go +++ b/internal/tools/looker/lookermakedashboard/lookermakedashboard.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "slices" yaml "github.com/goccy/go-yaml" @@ -116,21 +117,21 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "params = ", params) sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } paramsMap := params.AsMap() @@ -141,19 +142,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para mrespFields := "id,personal_folder_id" mresp, err := sdk.Me(mrespFields, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making me request: %s", err) + return nil, util.ProcessGeneralError(err) } if folder == "" { if mresp.PersonalFolderId == nil || *mresp.PersonalFolderId == "" { - return nil, fmt.Errorf("user does not have a personal folder. A folder must be specified") + return nil, util.NewAgentError("user does not have a personal folder. A folder must be specified", nil) } folder = *mresp.PersonalFolderId } dashs, err := sdk.FolderDashboards(folder, "title", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error getting existing dashboards in folder: %s", err) + return nil, util.ProcessGeneralError(err) } dashTitles := []string{} @@ -162,7 +163,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if slices.Contains(dashTitles, title) { lt, _ := json.Marshal(dashTitles) - return nil, fmt.Errorf("title %s already used in folder. Currently used titles are %v. Make the call again with a unique title", title, string(lt)) + return nil, util.NewAgentError(fmt.Sprintf("title %s already used in folder. Currently used titles are %v. Make the call again with a unique title", title, string(lt)), nil) } wd := v4.WriteDashboard{ @@ -172,7 +173,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.CreateDashboard(wd, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making create dashboard request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = %v", resp) diff --git a/internal/tools/looker/lookermakelook/lookermakelook.go b/internal/tools/looker/lookermakelook/lookermakelook.go index 46d6d61841..3dcaf91716 100644 --- a/internal/tools/looker/lookermakelook/lookermakelook.go +++ b/internal/tools/looker/lookermakelook/lookermakelook.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "slices" yaml "github.com/goccy/go-yaml" @@ -123,25 +124,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "params = ", params) wq, err := lookercommon.ProcessQueryArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error building query request: %w", err) + return nil, util.NewAgentError("error building query request", err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } paramsMap := params.AsMap() title := paramsMap["title"].(string) @@ -152,19 +153,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para mrespFields := "id,personal_folder_id" mresp, err := sdk.Me(mrespFields, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making me request: %s", err) + return nil, util.ProcessGeneralError(err) } if folder == "" { if mresp.PersonalFolderId == nil || *mresp.PersonalFolderId == "" { - return nil, fmt.Errorf("user does not have a personal folder. A folder must be specified") + return nil, util.NewAgentError("user does not have a personal folder. A folder must be specified", nil) } folder = *mresp.PersonalFolderId } looks, err := sdk.FolderLooks(folder, "title", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error getting existing looks in folder: %s", err) + return nil, util.ProcessGeneralError(err) } lookTitles := []string{} @@ -173,7 +174,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if slices.Contains(lookTitles, title) { lt, _ := json.Marshal(lookTitles) - return nil, fmt.Errorf("title %s already used in folder. Currently used titles are %v. Make the call again with a unique title", title, string(lt)) + return nil, util.NewAgentError(fmt.Sprintf("title %s already used in folder. Currently used titles are %v. Make the call again with a unique title", title, string(lt)), nil) } wq.VisConfig = &visConfig @@ -181,7 +182,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para qrespFields := "id" qresp, err := sdk.CreateQuery(*wq, qrespFields, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making create query request: %s", err) + return nil, util.ProcessGeneralError(err) } wlwq := v4.WriteLookWithQuery{ @@ -193,7 +194,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.CreateLook(wlwq, "", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making create look request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = %v", resp) diff --git a/internal/tools/looker/lookerquery/lookerquery.go b/internal/tools/looker/lookerquery/lookerquery.go index 1fb6e43f1e..2d099f7519 100644 --- a/internal/tools/looker/lookerquery/lookerquery.go +++ b/internal/tools/looker/lookerquery/lookerquery.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -109,27 +110,27 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } wq, err := lookercommon.ProcessQueryArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error building WriteQuery request: %w", err) + return nil, util.NewAgentError("error building WriteQuery request", err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } resp, err := lookercommon.RunInlineQuery(ctx, sdk, wq, "json", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making query request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = ", resp) @@ -137,7 +138,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var data []any e := json.Unmarshal([]byte(resp), &data) if e != nil { - return nil, fmt.Errorf("error unmarshaling query response: %s", e) + return nil, util.NewClientServerError("error unmarshaling query response", http.StatusInternalServerError, e) } logger.DebugContext(ctx, "data = ", data) diff --git a/internal/tools/looker/lookerquerysql/lookerquerysql.go b/internal/tools/looker/lookerquerysql/lookerquerysql.go index c796b4e90c..577b67678e 100644 --- a/internal/tools/looker/lookerquerysql/lookerquerysql.go +++ b/internal/tools/looker/lookerquerysql/lookerquerysql.go @@ -16,6 +16,7 @@ package lookerquerysql import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -108,27 +109,27 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } wq, err := lookercommon.ProcessQueryArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error building query request: %w", err) + return nil, util.NewAgentError("error building query request", err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } resp, err := lookercommon.RunInlineQuery(ctx, sdk, wq, "sql", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making query request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = ", resp) diff --git a/internal/tools/looker/lookerqueryurl/lookerqueryurl.go b/internal/tools/looker/lookerqueryurl/lookerqueryurl.go index 566745307c..3b9050b796 100644 --- a/internal/tools/looker/lookerqueryurl/lookerqueryurl.go +++ b/internal/tools/looker/lookerqueryurl/lookerqueryurl.go @@ -16,6 +16,7 @@ package lookerqueryurl import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -115,34 +116,37 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "params = ", params) wq, err := lookercommon.ProcessQueryArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error building query request: %w", err) + return nil, util.NewAgentError("error building query request", err) } paramsMap := params.AsMap() - visConfig := paramsMap["vis_config"].(map[string]any) + visConfig, ok := paramsMap["vis_config"].(map[string]any) + if !ok { + visConfig = make(map[string]any) + } wq.VisConfig = &visConfig sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } respFields := "id,slug,share_url,expanded_share_url" resp, err := sdk.CreateQuery(*wq, respFields, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making query request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = ", resp) diff --git a/internal/tools/looker/lookerrundashboard/lookerrundashboard.go b/internal/tools/looker/lookerrundashboard/lookerrundashboard.go index a40ccddc82..20c9433452 100644 --- a/internal/tools/looker/lookerrundashboard/lookerrundashboard.go +++ b/internal/tools/looker/lookerrundashboard/lookerrundashboard.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "sync" yaml "github.com/goccy/go-yaml" @@ -114,15 +115,15 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "params = ", params) paramsMap := params.AsMap() @@ -131,11 +132,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } dashboard, err := sdk.Dashboard(dashboard_id, "", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error getting dashboard: %w", err) + return nil, util.ProcessGeneralError(err) } data := make(map[string]any) diff --git a/internal/tools/looker/lookerrunlook/lookerrunlook.go b/internal/tools/looker/lookerrunlook/lookerrunlook.go index 2ab69a36a4..5dec4ae308 100644 --- a/internal/tools/looker/lookerrunlook/lookerrunlook.go +++ b/internal/tools/looker/lookerrunlook/lookerrunlook.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -115,15 +116,15 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "params = ", params) paramsMap := params.AsMap() @@ -134,12 +135,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } look, err := sdk.Look(look_id, "", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error getting look definition: %s", err) + return nil, util.ProcessGeneralError(err) } wq := v4.WriteQuery{ @@ -155,14 +156,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para resp, err := lookercommon.RunInlineQuery(ctx, sdk, &wq, "json", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making run_look request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = ", resp) var data []any e := json.Unmarshal([]byte(resp), &data) if e != nil { - return nil, fmt.Errorf("error Unmarshaling run_look response: %s", e) + return nil, util.NewClientServerError("error Unmarshaling run_look response", http.StatusInternalServerError, e) } logger.DebugContext(ctx, "data = ", data) diff --git a/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go b/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go index 284872b2cf..35ea7328b1 100644 --- a/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go +++ b/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go @@ -16,12 +16,14 @@ package lookerupdateprojectfile import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/looker-open-source/sdk-codegen/go/rtl" @@ -111,29 +113,29 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } mapParams := params.AsMap() projectId, ok := mapParams["project_id"].(string) if !ok { - return nil, fmt.Errorf("'project_id' must be a string, got %T", mapParams["project_id"]) + return nil, util.NewAgentError(fmt.Sprintf("'project_id' must be a string, got %T", mapParams["project_id"]), nil) } filePath, ok := mapParams["file_path"].(string) if !ok { - return nil, fmt.Errorf("'file_path' must be a string, got %T", mapParams["file_path"]) + return nil, util.NewAgentError(fmt.Sprintf("'file_path' must be a string, got %T", mapParams["file_path"]), nil) } fileContent, ok := mapParams["file_content"].(string) if !ok { - return nil, fmt.Errorf("'file_content' must be a string, got %T", mapParams["file_content"]) + return nil, util.NewAgentError(fmt.Sprintf("'file_content' must be a string, got %T", mapParams["file_content"]), nil) } req := lookercommon.FileContent{ @@ -143,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para err = lookercommon.UpdateProjectFile(sdk, projectId, req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making update_project_file request: %s", err) + return nil, util.ProcessGeneralError(err) } data := make(map[string]any) diff --git a/internal/tools/looker/lookervalidateproject/lookervalidateproject.go b/internal/tools/looker/lookervalidateproject/lookervalidateproject.go index e36c3a4dd2..b769ebde9e 100644 --- a/internal/tools/looker/lookervalidateproject/lookervalidateproject.go +++ b/internal/tools/looker/lookervalidateproject/lookervalidateproject.go @@ -16,6 +16,7 @@ package lookervalidateproject import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -108,31 +109,31 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("failed to initialize Looker SDK", http.StatusInternalServerError, err) } mapParams := params.AsMap() projectId, ok := mapParams["project_id"].(string) if !ok { - return nil, fmt.Errorf("'project_id' must be a string, got %T", mapParams["project_id"]) + return nil, util.NewAgentError(fmt.Sprintf("'project_id' must be a string, got %T", mapParams["project_id"]), nil) } resp, err := sdk.ValidateProject(projectId, "", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making validate_project request: %w", err) + return nil, util.ProcessGeneralError(fmt.Errorf("error making validate_project request: %w", err)) } logger.DebugContext(ctx, "Got response of %v\n", resp) diff --git a/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go b/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go index bcbd5fba9a..9e0e3c644a 100644 --- a/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go +++ b/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -55,7 +57,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -73,7 +74,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) InputSchema: inputSchema, } - // finish tool setup t := Tool{ Config: cfg, Parameters: params, @@ -83,7 +83,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -97,19 +96,23 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - sql, ok := paramsMap["sql"].(string) + sqlStr, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql"]), nil) } - return source.RunSQL(ctx, sql, nil) + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go b/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go index 0eb1c8eea7..782846ca6b 100644 --- a/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go +++ b/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -58,7 +60,6 @@ type Config struct { TemplateParameters parameters.Parameters `yaml:"templateParameters"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -79,7 +80,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) InputSchema: paramMcpManifest, } - // finish tool setup t := Tool{ Config: cfg, AllParams: allParameters, @@ -89,7 +89,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -99,25 +98,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + resp, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go b/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go index 24dce16680..519e79b696 100644 --- a/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go +++ b/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go @@ -16,10 +16,12 @@ package mongodbaggregate import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" @@ -102,18 +104,22 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() pipelineString, err := parameters.PopulateTemplateWithJSON("MongoDBAggregatePipeline", t.PipelinePayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating pipeline: %s", err) + return nil, util.NewAgentError("error populating pipeline", err) } - return source.Aggregate(ctx, pipelineString, t.Canonical, t.ReadOnly, t.Database, t.Collection) + resp, err := source.Aggregate(ctx, pipelineString, t.Canonical, t.ReadOnly, t.Database, t.Collection) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go b/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go index 0d6f8c2be8..8f67aec41f 100644 --- a/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go +++ b/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go @@ -16,10 +16,12 @@ package mongodbdeletemany import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" @@ -106,18 +108,22 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBDeleteManyFilter", t.FilterPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating filter: %s", err) + return nil, util.NewAgentError("error populating filter", err) } - return source.DeleteMany(ctx, filterString, t.Database, t.Collection) + resp, err := source.DeleteMany(ctx, filterString, t.Database, t.Collection) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go b/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go index 416a67ffe3..55f855c7e0 100644 --- a/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go +++ b/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go @@ -9,17 +9,19 @@ // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and +// See the License for the language governing permissions and // limitations under the License. package mongodbdeleteone import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" @@ -106,19 +108,23 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBDeleteOneFilter", t.FilterPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating filter: %s", err) + return nil, util.NewAgentError("error populating filter", err) } - return source.DeleteOne(ctx, filterString, t.Database, t.Collection) + resp, err := source.DeleteOne(ctx, filterString, t.Database, t.Collection) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbfind/mongodbfind.go b/internal/tools/mongodb/mongodbfind/mongodbfind.go index 9389088f0f..9da5a8b30e 100644 --- a/internal/tools/mongodb/mongodbfind/mongodbfind.go +++ b/internal/tools/mongodb/mongodbfind/mongodbfind.go @@ -16,6 +16,7 @@ package mongodbfind import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" @@ -121,7 +122,7 @@ type Tool struct { func getOptions(ctx context.Context, sortParameters parameters.Parameters, projectPayload string, limit int64, paramsMap map[string]any) (*options.FindOptionsBuilder, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { - panic(err) + return nil, err } opts := options.Find() @@ -157,22 +158,26 @@ func getOptions(ctx context.Context, sortParameters parameters.Parameters, proje return opts, nil } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBFindFilterString", t.FilterPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating filter: %s", err) + return nil, util.NewAgentError("error populating filter", err) } opts, err := getOptions(ctx, t.SortParams, t.ProjectPayload, t.Limit, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating options: %s", err) + return nil, util.NewAgentError("error populating options", err) } - return source.Find(ctx, filterString, t.Database, t.Collection, opts) + resp, err := source.Find(ctx, filterString, t.Database, t.Collection, opts) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbfindone/mongodbfindone.go b/internal/tools/mongodb/mongodbfindone/mongodbfindone.go index 3822d1302a..f75af4328a 100644 --- a/internal/tools/mongodb/mongodbfindone/mongodbfindone.go +++ b/internal/tools/mongodb/mongodbfindone/mongodbfindone.go @@ -16,10 +16,12 @@ package mongodbfindone import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" @@ -110,32 +112,36 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBFindOneFilterString", t.FilterPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating filter: %s", err) + return nil, util.NewAgentError("error populating filter", err) } opts := options.FindOne() if len(t.ProjectPayload) > 0 { result, err := parameters.PopulateTemplateWithJSON("MongoDBFindOneProjectString", t.ProjectPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating project payload: %s", err) + return nil, util.NewAgentError("error populating project payload", err) } var projection any err = bson.UnmarshalExtJSON([]byte(result), false, &projection) if err != nil { - return nil, fmt.Errorf("error unmarshalling projection: %s", err) + return nil, util.NewAgentError("error unmarshalling projection", err) } opts = opts.SetProjection(projection) } - return source.FindOne(ctx, filterString, t.Database, t.Collection, opts) + resp, err := source.FindOne(ctx, filterString, t.Database, t.Collection, opts) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go b/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go index 17a8020635..0de1cc8de4 100644 --- a/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go +++ b/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go @@ -15,13 +15,14 @@ package mongodbinsertmany import ( "context" - "errors" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" ) @@ -100,23 +101,27 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } if len(params) == 0 { - return nil, errors.New("no input found") + return nil, util.NewAgentError("no input found", nil) } paramsMap := params.AsMap() jsonData, ok := paramsMap[paramDataKey].(string) if !ok { - return nil, errors.New("no input found") + return nil, util.NewAgentError("no input found or invalid type for data", nil) } - return source.InsertMany(ctx, jsonData, t.Canonical, t.Database, t.Collection) + resp, err := source.InsertMany(ctx, jsonData, t.Canonical, t.Database, t.Collection) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go b/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go index d4e9f7f072..3fc9260a51 100644 --- a/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go +++ b/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go @@ -15,13 +15,14 @@ package mongodbinsertone import ( "context" - "errors" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" ) @@ -101,20 +102,24 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } if len(params) == 0 { - return nil, errors.New("no input found") + return nil, util.NewAgentError("no input found", nil) } // use the first, assume it's a string jsonData, ok := params[0].Value.(string) if !ok { - return nil, errors.New("no input found") + return nil, util.NewAgentError("no input found or invalid type for data", nil) } - return source.InsertOne(ctx, jsonData, t.Canonical, t.Database, t.Collection) + resp, err := source.InsertOne(ctx, jsonData, t.Canonical, t.Database, t.Collection) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go b/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go index d7a7cd569b..7e34e52384 100644 --- a/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go +++ b/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go @@ -16,12 +16,14 @@ package mongodbupdatemany import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" ) @@ -109,22 +111,26 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateManyFilter", t.FilterPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating filter: %s", err) + return nil, util.NewAgentError("error populating filter", err) } updateString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateMany", t.UpdatePayload, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to get update: %w", err) + return nil, util.NewAgentError("unable to get update", err) } - return source.UpdateMany(ctx, filterString, t.Canonical, updateString, t.Database, t.Collection, t.Upsert) + resp, err := source.UpdateMany(ctx, filterString, t.Canonical, updateString, t.Database, t.Collection, t.Upsert) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go b/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go index 2fa99efb67..6369e08a91 100644 --- a/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go +++ b/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go @@ -16,12 +16,14 @@ package mongodbupdateone import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" ) @@ -110,22 +112,26 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateOneFilter", t.FilterPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating filter: %s", err) + return nil, util.NewAgentError("error populating filter", err) } updateString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateOne", t.UpdatePayload, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to get update: %w", err) + return nil, util.NewAgentError("unable to get update", err) } - return source.UpdateOne(ctx, filterString, t.Canonical, updateString, t.Database, t.Collection, t.Upsert) + resp, err := source.UpdateOne(ctx, filterString, t.Canonical, updateString, t.Database, t.Collection, t.Upsert) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go index 3b00090823..ae3d497b84 100644 --- a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go +++ b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -89,25 +90,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - sql, ok := paramsMap["sql"].(string) + sqlStr, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql"]), nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, sql, nil) + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sqlStr)) + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mssql/mssqllisttables/mssqllisttables.go b/internal/tools/mssql/mssqllisttables/mssqllisttables.go index 6798087768..ea462e2740 100644 --- a/internal/tools/mssql/mssqllisttables/mssqllisttables.go +++ b/internal/tools/mssql/mssqllisttables/mssqllisttables.go @@ -18,263 +18,265 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) const resourceType string = "mssql-list-tables" const listTablesStatement = ` - WITH table_info AS ( - SELECT - t.object_id AS table_oid, - s.name AS schema_name, - t.name AS table_name, - dp.name AS table_owner, -- Schema's owner principal name - CAST(ep.value AS NVARCHAR(MAX)) AS table_comment, -- Cast for JSON compatibility - CASE - WHEN EXISTS ( -- Check if the table has more than one partition for any of its indexes or heap - SELECT 1 FROM sys.partitions p - WHERE p.object_id = t.object_id AND p.partition_number > 1 - ) THEN 'PARTITIONED TABLE' - ELSE 'TABLE' - END AS object_type_detail - FROM - sys.tables t - INNER JOIN - sys.schemas s ON t.schema_id = s.schema_id - LEFT JOIN - sys.database_principals dp ON s.principal_id = dp.principal_id - LEFT JOIN - sys.extended_properties ep ON ep.major_id = t.object_id AND ep.minor_id = 0 AND ep.class = 1 AND ep.name = 'MS_Description' - WHERE - t.type = 'U' -- User tables - AND s.name NOT IN ('sys', 'INFORMATION_SCHEMA', 'guest', 'db_owner', 'db_accessadmin', 'db_backupoperator', 'db_datareader', 'db_datawriter', 'db_ddladmin', 'db_denydatareader', 'db_denydatawriter', 'db_securityadmin') - AND (@table_names IS NULL OR LTRIM(RTRIM(@table_names)) = '' OR t.name IN (SELECT LTRIM(RTRIM(value)) FROM STRING_SPLIT(@table_names, ','))) - ), - columns_info AS ( - SELECT - c.object_id AS table_oid, - c.name AS column_name, - CONCAT( - UPPER(TY.name), -- Base type name - CASE - WHEN TY.name IN ('char', 'varchar', 'nchar', 'nvarchar', 'binary', 'varbinary') THEN - CONCAT('(', IIF(c.max_length = -1, 'MAX', CAST(c.max_length / CASE WHEN TY.name IN ('nchar', 'nvarchar') THEN 2 ELSE 1 END AS VARCHAR(10))), ')') - WHEN TY.name IN ('decimal', 'numeric') THEN - CONCAT('(', c.precision, ',', c.scale, ')') - WHEN TY.name IN ('datetime2', 'datetimeoffset', 'time') THEN - CONCAT('(', c.scale, ')') - ELSE '' - END - ) AS data_type, - c.column_id AS column_ordinal_position, - IIF(c.is_nullable = 0, CAST(1 AS BIT), CAST(0 AS BIT)) AS is_not_nullable, - dc.definition AS column_default, - CAST(epc.value AS NVARCHAR(MAX)) AS column_comment - FROM - sys.columns c - JOIN - table_info ti ON c.object_id = ti.table_oid - JOIN - sys.types TY ON c.user_type_id = TY.user_type_id AND TY.is_user_defined = 0 -- Ensure we get base types - LEFT JOIN - sys.default_constraints dc ON c.object_id = dc.parent_object_id AND c.column_id = dc.parent_column_id - LEFT JOIN - sys.extended_properties epc ON epc.major_id = c.object_id AND epc.minor_id = c.column_id AND epc.class = 1 AND epc.name = 'MS_Description' - ), - constraints_info AS ( - -- Primary Keys & Unique Constraints - SELECT - kc.parent_object_id AS table_oid, - kc.name AS constraint_name, - REPLACE(kc.type_desc, '_CONSTRAINT', '') AS constraint_type, -- 'PRIMARY_KEY', 'UNIQUE' - STUFF((SELECT ', ' + col.name - FROM sys.index_columns ic - JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id - WHERE ic.object_id = kc.parent_object_id AND ic.index_id = kc.unique_index_id - ORDER BY ic.key_ordinal - FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS constraint_columns, - NULL AS foreign_key_referenced_table, - NULL AS foreign_key_referenced_columns, - CASE kc.type - WHEN 'PK' THEN 'PRIMARY KEY (' + STUFF((SELECT ', ' + col.name FROM sys.index_columns ic JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id WHERE ic.object_id = kc.parent_object_id AND ic.index_id = kc.unique_index_id ORDER BY ic.key_ordinal FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') + ')' - WHEN 'UQ' THEN 'UNIQUE (' + STUFF((SELECT ', ' + col.name FROM sys.index_columns ic JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id WHERE ic.object_id = kc.parent_object_id AND ic.index_id = kc.unique_index_id ORDER BY ic.key_ordinal FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') + ')' - END AS constraint_definition - FROM sys.key_constraints kc - JOIN table_info ti ON kc.parent_object_id = ti.table_oid - UNION ALL - -- Foreign Keys - SELECT - fk.parent_object_id AS table_oid, - fk.name AS constraint_name, - 'FOREIGN KEY' AS constraint_type, - STUFF((SELECT ', ' + pc.name - FROM sys.foreign_key_columns fkc - JOIN sys.columns pc ON fkc.parent_object_id = pc.object_id AND fkc.parent_column_id = pc.column_id - WHERE fkc.constraint_object_id = fk.object_id - ORDER BY fkc.constraint_column_id - FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS constraint_columns, - SCHEMA_NAME(rt.schema_id) + '.' + OBJECT_NAME(fk.referenced_object_id) AS foreign_key_referenced_table, - STUFF((SELECT ', ' + rc.name - FROM sys.foreign_key_columns fkc - JOIN sys.columns rc ON fkc.referenced_object_id = rc.object_id AND fkc.referenced_column_id = rc.column_id - WHERE fkc.constraint_object_id = fk.object_id - ORDER BY fkc.constraint_column_id - FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS foreign_key_referenced_columns, - OBJECT_DEFINITION(fk.object_id) AS constraint_definition - FROM sys.foreign_keys fk - JOIN sys.tables rt ON fk.referenced_object_id = rt.object_id - JOIN table_info ti ON fk.parent_object_id = ti.table_oid - UNION ALL - -- Check Constraints - SELECT - cc.parent_object_id AS table_oid, - cc.name AS constraint_name, - 'CHECK' AS constraint_type, - NULL AS constraint_columns, -- Definition includes column context - NULL AS foreign_key_referenced_table, - NULL AS foreign_key_referenced_columns, - cc.definition AS constraint_definition - FROM sys.check_constraints cc - JOIN table_info ti ON cc.parent_object_id = ti.table_oid - ), - indexes_info AS ( - SELECT - i.object_id AS table_oid, - i.name AS index_name, - i.type_desc AS index_method, -- CLUSTERED, NONCLUSTERED, XML, etc. - i.is_unique, - i.is_primary_key AS is_primary, - STUFF((SELECT ', ' + c.name - FROM sys.index_columns ic - JOIN sys.columns c ON i.object_id = c.object_id AND ic.column_id = c.column_id - WHERE ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.is_included_column = 0 - ORDER BY ic.key_ordinal - FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS index_columns, - ( - 'COLUMNS: (' + ISNULL(STUFF((SELECT ', ' + c.name + CASE WHEN ic.is_descending_key = 1 THEN ' DESC' ELSE '' END - FROM sys.index_columns ic - JOIN sys.columns c ON i.object_id = c.object_id AND ic.column_id = c.column_id - WHERE ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.is_included_column = 0 - ORDER BY ic.key_ordinal FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, ''), 'N/A') + ')' + - ISNULL(CHAR(13)+CHAR(10) + 'INCLUDE: (' + STUFF((SELECT ', ' + c.name - FROM sys.index_columns ic - JOIN sys.columns c ON i.object_id = c.object_id AND ic.column_id = c.column_id - WHERE ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.is_included_column = 1 - ORDER BY ic.index_column_id FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') + ')', '') + - ISNULL(CHAR(13)+CHAR(10) + 'FILTER: (' + i.filter_definition + ')', '') - ) AS index_definition_details - FROM - sys.indexes i - JOIN - table_info ti ON i.object_id = ti.table_oid - WHERE i.type <> 0 -- Exclude Heaps - AND i.name IS NOT NULL -- Exclude unnamed heap indexes; named indexes (PKs are often named) are preferred. - ), - triggers_info AS ( - SELECT - tr.parent_id AS table_oid, - tr.name AS trigger_name, - OBJECT_DEFINITION(tr.object_id) AS trigger_definition, - CASE tr.is_disabled WHEN 0 THEN 'ENABLED' ELSE 'DISABLED' END AS trigger_enabled_state - FROM - sys.triggers tr - JOIN - table_info ti ON tr.parent_id = ti.table_oid - WHERE - tr.is_ms_shipped = 0 - AND tr.parent_class_desc = 'OBJECT_OR_COLUMN' -- DML Triggers on tables/views - ) - SELECT - ti.schema_name, - ti.table_name AS object_name, - CASE - WHEN @output_format = 'simple' THEN - (SELECT ti.table_name AS name FOR JSON PATH, WITHOUT_ARRAY_WRAPPER) - ELSE - ( - SELECT - ti.schema_name AS schema_name, - ti.table_name AS object_name, - ti.object_type_detail AS object_type, - ti.table_owner AS owner, - ti.table_comment AS comment, - JSON_QUERY(ISNULL(( - SELECT - ci.column_name, - ci.data_type, - ci.column_ordinal_position, - ci.is_not_nullable, - ci.column_default, - ci.column_comment - FROM columns_info ci - WHERE ci.table_oid = ti.table_oid - ORDER BY ci.column_ordinal_position - FOR JSON PATH - ), '[]')) AS columns, - JSON_QUERY(ISNULL(( - SELECT - cons.constraint_name, - cons.constraint_type, - cons.constraint_definition, - JSON_QUERY( - CASE - WHEN cons.constraint_columns IS NOT NULL AND LTRIM(RTRIM(cons.constraint_columns)) <> '' - THEN '[' + (SELECT STRING_AGG('"' + LTRIM(RTRIM(value)) + '"', ',') FROM STRING_SPLIT(cons.constraint_columns, ',')) + ']' - ELSE '[]' - END - ) AS constraint_columns, - cons.foreign_key_referenced_table, - JSON_QUERY( - CASE - WHEN cons.foreign_key_referenced_columns IS NOT NULL AND LTRIM(RTRIM(cons.foreign_key_referenced_columns)) <> '' - THEN '[' + (SELECT STRING_AGG('"' + LTRIM(RTRIM(value)) + '"', ',') FROM STRING_SPLIT(cons.foreign_key_referenced_columns, ',')) + ']' - ELSE '[]' - END - ) AS foreign_key_referenced_columns - FROM constraints_info cons - WHERE cons.table_oid = ti.table_oid - FOR JSON PATH - ), '[]')) AS constraints, - JSON_QUERY(ISNULL(( - SELECT - ii.index_name, - ii.index_definition_details AS index_definition, - ii.is_unique, - ii.is_primary, - ii.index_method, - JSON_QUERY( - CASE - WHEN ii.index_columns IS NOT NULL AND LTRIM(RTRIM(ii.index_columns)) <> '' - THEN '[' + (SELECT STRING_AGG('"' + LTRIM(RTRIM(value)) + '"', ',') FROM STRING_SPLIT(ii.index_columns, ',')) + ']' - ELSE '[]' - END - ) AS index_columns - FROM indexes_info ii - WHERE ii.table_oid = ti.table_oid - FOR JSON PATH - ), '[]')) AS indexes, - JSON_QUERY(ISNULL(( - SELECT - tri.trigger_name, - tri.trigger_definition, - tri.trigger_enabled_state - FROM triggers_info tri - WHERE tri.table_oid = ti.table_oid - FOR JSON PATH - ), '[]')) AS triggers - FOR JSON PATH, WITHOUT_ARRAY_WRAPPER -- Creates a single JSON object for this table's details - ) - END AS object_details - FROM - table_info ti - ORDER BY - ti.schema_name, ti.table_name; + WITH table_info AS ( + SELECT + t.object_id AS table_oid, + s.name AS schema_name, + t.name AS table_name, + dp.name AS table_owner, -- Schema's owner principal name + CAST(ep.value AS NVARCHAR(MAX)) AS table_comment, -- Cast for JSON compatibility + CASE + WHEN EXISTS ( -- Check if the table has more than one partition for any of its indexes or heap + SELECT 1 FROM sys.partitions p + WHERE p.object_id = t.object_id AND p.partition_number > 1 + ) THEN 'PARTITIONED TABLE' + ELSE 'TABLE' + END AS object_type_detail + FROM + sys.tables t + INNER JOIN + sys.schemas s ON t.schema_id = s.schema_id + LEFT JOIN + sys.database_principals dp ON s.principal_id = dp.principal_id + LEFT JOIN + sys.extended_properties ep ON ep.major_id = t.object_id AND ep.minor_id = 0 AND ep.class = 1 AND ep.name = 'MS_Description' + WHERE + t.type = 'U' -- User tables + AND s.name NOT IN ('sys', 'INFORMATION_SCHEMA', 'guest', 'db_owner', 'db_accessadmin', 'db_backupoperator', 'db_datareader', 'db_datawriter', 'db_ddladmin', 'db_denydatareader', 'db_denydatawriter', 'db_securityadmin') + AND (@table_names IS NULL OR LTRIM(RTRIM(@table_names)) = '' OR t.name IN (SELECT LTRIM(RTRIM(value)) FROM STRING_SPLIT(@table_names, ','))) + ), + columns_info AS ( + SELECT + c.object_id AS table_oid, + c.name AS column_name, + CONCAT( + UPPER(TY.name), -- Base type name + CASE + WHEN TY.name IN ('char', 'varchar', 'nchar', 'nvarchar', 'binary', 'varbinary') THEN + CONCAT('(', IIF(c.max_length = -1, 'MAX', CAST(c.max_length / CASE WHEN TY.name IN ('nchar', 'nvarchar') THEN 2 ELSE 1 END AS VARCHAR(10))), ')') + WHEN TY.name IN ('decimal', 'numeric') THEN + CONCAT('(', c.precision, ',', c.scale, ')') + WHEN TY.name IN ('datetime2', 'datetimeoffset', 'time') THEN + CONCAT('(', c.scale, ')') + ELSE '' + END + ) AS data_type, + c.column_id AS column_ordinal_position, + IIF(c.is_nullable = 0, CAST(1 AS BIT), CAST(0 AS BIT)) AS is_not_nullable, + dc.definition AS column_default, + CAST(epc.value AS NVARCHAR(MAX)) AS column_comment + FROM + sys.columns c + JOIN + table_info ti ON c.object_id = ti.table_oid + JOIN + sys.types TY ON c.user_type_id = TY.user_type_id AND TY.is_user_defined = 0 -- Ensure we get base types + LEFT JOIN + sys.default_constraints dc ON c.object_id = dc.parent_object_id AND c.column_id = dc.parent_column_id + LEFT JOIN + sys.extended_properties epc ON epc.major_id = c.object_id AND epc.minor_id = c.column_id AND epc.class = 1 AND epc.name = 'MS_Description' + ), + constraints_info AS ( + -- Primary Keys & Unique Constraints + SELECT + kc.parent_object_id AS table_oid, + kc.name AS constraint_name, + REPLACE(kc.type_desc, '_CONSTRAINT', '') AS constraint_type, -- 'PRIMARY_KEY', 'UNIQUE' + STUFF((SELECT ', ' + col.name + FROM sys.index_columns ic + JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id + WHERE ic.object_id = kc.parent_object_id AND ic.index_id = kc.unique_index_id + ORDER BY ic.key_ordinal + FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS constraint_columns, + NULL AS foreign_key_referenced_table, + NULL AS foreign_key_referenced_columns, + CASE kc.type + WHEN 'PK' THEN 'PRIMARY KEY (' + STUFF((SELECT ', ' + col.name FROM sys.index_columns ic JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id WHERE ic.object_id = kc.parent_object_id AND ic.index_id = kc.unique_index_id ORDER BY ic.key_ordinal FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') + ')' + WHEN 'UQ' THEN 'UNIQUE (' + STUFF((SELECT ', ' + col.name FROM sys.index_columns ic JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id WHERE ic.object_id = kc.parent_object_id AND ic.index_id = kc.unique_index_id ORDER BY ic.key_ordinal FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') + ')' + END AS constraint_definition + FROM sys.key_constraints kc + JOIN table_info ti ON kc.parent_object_id = ti.table_oid + UNION ALL + -- Foreign Keys + SELECT + fk.parent_object_id AS table_oid, + fk.name AS constraint_name, + 'FOREIGN KEY' AS constraint_type, + STUFF((SELECT ', ' + pc.name + FROM sys.foreign_key_columns fkc + JOIN sys.columns pc ON fkc.parent_object_id = pc.object_id AND fkc.parent_column_id = pc.column_id + WHERE fkc.constraint_object_id = fk.object_id + ORDER BY fkc.constraint_column_id + FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS constraint_columns, + SCHEMA_NAME(rt.schema_id) + '.' + OBJECT_NAME(fk.referenced_object_id) AS foreign_key_referenced_table, + STUFF((SELECT ', ' + rc.name + FROM sys.foreign_key_columns fkc + JOIN sys.columns rc ON fkc.referenced_object_id = rc.object_id AND fkc.referenced_column_id = rc.column_id + WHERE fkc.constraint_object_id = fk.object_id + ORDER BY fkc.constraint_column_id + FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS foreign_key_referenced_columns, + OBJECT_DEFINITION(fk.object_id) AS constraint_definition + FROM sys.foreign_keys fk + JOIN sys.tables rt ON fk.referenced_object_id = rt.object_id + JOIN table_info ti ON fk.parent_object_id = ti.table_oid + UNION ALL + -- Check Constraints + SELECT + cc.parent_object_id AS table_oid, + cc.name AS constraint_name, + 'CHECK' AS constraint_type, + NULL AS constraint_columns, -- Definition includes column context + NULL AS foreign_key_referenced_table, + NULL AS foreign_key_referenced_columns, + cc.definition AS constraint_definition + FROM sys.check_constraints cc + JOIN table_info ti ON cc.parent_object_id = ti.table_oid + ), + indexes_info AS ( + SELECT + i.object_id AS table_oid, + i.name AS index_name, + i.type_desc AS index_method, -- CLUSTERED, NONCLUSTERED, XML, etc. + i.is_unique, + i.is_primary_key AS is_primary, + STUFF((SELECT ', ' + c.name + FROM sys.index_columns ic + JOIN sys.columns c ON i.object_id = c.object_id AND ic.column_id = c.column_id + WHERE ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.is_included_column = 0 + ORDER BY ic.key_ordinal + FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS index_columns, + ( + 'COLUMNS: (' + ISNULL(STUFF((SELECT ', ' + c.name + CASE WHEN ic.is_descending_key = 1 THEN ' DESC' ELSE '' END + FROM sys.index_columns ic + JOIN sys.columns c ON i.object_id = c.object_id AND ic.column_id = c.column_id + WHERE ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.is_included_column = 0 + ORDER BY ic.key_ordinal FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, ''), 'N/A') + ')' + + ISNULL(CHAR(13)+CHAR(10) + 'INCLUDE: (' + STUFF((SELECT ', ' + c.name + FROM sys.index_columns ic + JOIN sys.columns c ON i.object_id = c.object_id AND ic.column_id = c.column_id + WHERE ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.is_included_column = 1 + ORDER BY ic.index_column_id FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') + ')', '') + + ISNULL(CHAR(13)+CHAR(10) + 'FILTER: (' + i.filter_definition + ')', '') + ) AS index_definition_details + FROM + sys.indexes i + JOIN + table_info ti ON i.object_id = ti.table_oid + WHERE i.type <> 0 -- Exclude Heaps + AND i.name IS NOT NULL -- Exclude unnamed heap indexes; named indexes (PKs are often named) are preferred. + ), + triggers_info AS ( + SELECT + tr.parent_id AS table_oid, + tr.name AS trigger_name, + OBJECT_DEFINITION(tr.object_id) AS trigger_definition, + CASE tr.is_disabled WHEN 0 THEN 'ENABLED' ELSE 'DISABLED' END AS trigger_enabled_state + FROM + sys.triggers tr + JOIN + table_info ti ON tr.parent_id = ti.table_oid + WHERE + tr.is_ms_shipped = 0 + AND tr.parent_class_desc = 'OBJECT_OR_COLUMN' -- DML Triggers on tables/views + ) + SELECT + ti.schema_name, + ti.table_name AS object_name, + CASE + WHEN @output_format = 'simple' THEN + (SELECT ti.table_name AS name FOR JSON PATH, WITHOUT_ARRAY_WRAPPER) + ELSE + ( + SELECT + ti.schema_name AS schema_name, + ti.table_name AS object_name, + ti.object_type_detail AS object_type, + ti.table_owner AS owner, + ti.table_comment AS comment, + JSON_QUERY(ISNULL(( + SELECT + ci.column_name, + ci.data_type, + ci.column_ordinal_position, + ci.is_not_nullable, + ci.column_default, + ci.column_comment + FROM columns_info ci + WHERE ci.table_oid = ti.table_oid + ORDER BY ci.column_ordinal_position + FOR JSON PATH + ), '[]')) AS columns, + JSON_QUERY(ISNULL(( + SELECT + cons.constraint_name, + cons.constraint_type, + cons.constraint_definition, + JSON_QUERY( + CASE + WHEN cons.constraint_columns IS NOT NULL AND LTRIM(RTRIM(cons.constraint_columns)) <> '' + THEN '[' + (SELECT STRING_AGG('"' + LTRIM(RTRIM(value)) + '"', ',') FROM STRING_SPLIT(cons.constraint_columns, ',')) + ']' + ELSE '[]' + END + ) AS constraint_columns, + cons.foreign_key_referenced_table, + JSON_QUERY( + CASE + WHEN cons.foreign_key_referenced_columns IS NOT NULL AND LTRIM(RTRIM(cons.foreign_key_referenced_columns)) <> '' + THEN '[' + (SELECT STRING_AGG('"' + LTRIM(RTRIM(value)) + '"', ',') FROM STRING_SPLIT(cons.foreign_key_referenced_columns, ',')) + ']' + ELSE '[]' + END + ) AS foreign_key_referenced_columns + FROM constraints_info cons + WHERE cons.table_oid = ti.table_oid + FOR JSON PATH + ), '[]')) AS constraints, + JSON_QUERY(ISNULL(( + SELECT + ii.index_name, + ii.index_definition_details AS index_definition, + ii.is_unique, + ii.is_primary, + ii.index_method, + JSON_QUERY( + CASE + WHEN ii.index_columns IS NOT NULL AND LTRIM(RTRIM(ii.index_columns)) <> '' + THEN '[' + (SELECT STRING_AGG('"' + LTRIM(RTRIM(value)) + '"', ',') FROM STRING_SPLIT(ii.index_columns, ',')) + ']' + ELSE '[]' + END + ) AS index_columns + FROM indexes_info ii + WHERE ii.table_oid = ti.table_oid + FOR JSON PATH + ), '[]')) AS indexes, + JSON_QUERY(ISNULL(( + SELECT + tri.trigger_name, + tri.trigger_definition, + tri.trigger_enabled_state + FROM triggers_info tri + WHERE tri.table_oid = ti.table_oid + FOR JSON PATH + ), '[]')) AS triggers + FOR JSON PATH, WITHOUT_ARRAY_WRAPPER -- Creates a single JSON object for this table's details + ) + END AS object_details + FROM + table_info ti + ORDER BY + ti.schema_name, ti.table_name; ` func init() { @@ -339,17 +341,17 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() outputFormat, _ := paramsMap["output_format"].(string) if outputFormat != "simple" && outputFormat != "detailed" { - return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) + return nil, util.NewAgentError(fmt.Sprintf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat), nil) } namedArgs := []any{ @@ -358,14 +360,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := source.RunSQL(ctx, listTablesStatement, namedArgs) if err != nil { - return nil, err + return nil, util.ProcessGeneralError(err) } // if there's no results, return empty list instead of null resSlice, ok := resp.([]any) if !ok || len(resSlice) == 0 { return []any{}, nil } - return resp, err + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mssql/mssqlsql/mssqlsql.go b/internal/tools/mssql/mssqlsql/mssqlsql.go index 57b67ec9ac..4e5878e89f 100644 --- a/internal/tools/mssql/mssqlsql/mssqlsql.go +++ b/internal/tools/mssql/mssqlsql/mssqlsql.go @@ -18,12 +18,14 @@ import ( "context" "database/sql" "fmt" + "net/http" "strings" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -94,21 +96,21 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } namedArgs := make([]any, 0, len(newParams)) @@ -123,7 +125,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para namedArgs = append(namedArgs, value) } } - return source.RunSQL(ctx, newStatement, namedArgs) + resp, err := source.RunSQL(ctx, newStatement, namedArgs) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go index fb4c6a0a97..4363ba2ed7 100644 --- a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go +++ b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -89,25 +90,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - sql, ok := paramsMap["sql"].(string) + sqlStr, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql"]), nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, sql, nil) + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sqlStr)) + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go index d152cc2394..b2e6008af2 100644 --- a/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go +++ b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -91,46 +92,46 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - sql, ok := paramsMap["sql_statement"].(string) + sqlStr, ok := paramsMap["sql_statement"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql_statement"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql_statement"]), nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sqlStr)) - query := fmt.Sprintf("EXPLAIN FORMAT=JSON %s", sql) + query := fmt.Sprintf("EXPLAIN FORMAT=JSON %s", sqlStr) result, err := source.RunSQL(ctx, query, nil) if err != nil { - return nil, err + return nil, util.ProcessGeneralError(err) } // extract and return only the query plan object resSlice, ok := result.([]any) if !ok || len(resSlice) == 0 { - return nil, fmt.Errorf("no query plan returned") + return nil, util.NewClientServerError("no query plan returned", http.StatusInternalServerError, nil) } row, ok := resSlice[0].(orderedmap.Row) if !ok || len(row.Columns) == 0 { - return nil, fmt.Errorf("no query plan returned in row") + return nil, util.NewClientServerError("no query plan returned in row", http.StatusInternalServerError, nil) } plan, ok := row.Columns[0].Value.(string) if !ok { - return nil, fmt.Errorf("unable to convert plan object to string") + return nil, util.NewClientServerError("unable to convert plan object to string", http.StatusInternalServerError, nil) } var out map[string]any if err := json.Unmarshal([]byte(plan), &out); err != nil { - return nil, fmt.Errorf("failed to unmarshal query plan json: %w", err) + return nil, util.NewClientServerError("failed to unmarshal query plan json", http.StatusInternalServerError, err) } return out, nil } diff --git a/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go b/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go index d08b57f0ce..3437657da6 100644 --- a/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go +++ b/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -32,65 +33,65 @@ import ( const resourceType string = "mysql-list-active-queries" const listActiveQueriesStatementMySQL = ` - SELECT - p.id AS processlist_id, - substring(IFNULL(p.info, t.trx_query), 1, 100) AS query, - t.trx_started AS trx_started, - (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_started)) AS trx_duration_seconds, - (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_wait_started)) AS trx_wait_duration_seconds, - p.time AS query_time, - t.trx_state AS trx_state, - p.state AS process_state, - IF(p.host IS NULL OR p.host = '', p.user, concat(p.user, '@', SUBSTRING_INDEX(p.host, ':', 1))) AS user, - t.trx_rows_locked AS trx_rows_locked, - t.trx_rows_modified AS trx_rows_modified, - p.db AS db - FROM - information_schema.processlist p - LEFT OUTER JOIN - information_schema.innodb_trx t - ON p.id = t.trx_mysql_thread_id - WHERE - (? IS NULL OR p.time >= ?) - AND p.id != CONNECTION_ID() - AND Command NOT IN ('Binlog Dump', 'Binlog Dump GTID', 'Connect', 'Connect Out', 'Register Slave') - AND User NOT IN ('system user', 'event_scheduler') - AND (t.trx_id is NOT NULL OR command != 'Sleep') - ORDER BY - t.trx_started - LIMIT ?; + SELECT + p.id AS processlist_id, + substring(IFNULL(p.info, t.trx_query), 1, 100) AS query, + t.trx_started AS trx_started, + (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_started)) AS trx_duration_seconds, + (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_wait_started)) AS trx_wait_duration_seconds, + p.time AS query_time, + t.trx_state AS trx_state, + p.state AS process_state, + IF(p.host IS NULL OR p.host = '', p.user, concat(p.user, '@', SUBSTRING_INDEX(p.host, ':', 1))) AS user, + t.trx_rows_locked AS trx_rows_locked, + t.trx_rows_modified AS trx_rows_modified, + p.db AS db + FROM + information_schema.processlist p + LEFT OUTER JOIN + information_schema.innodb_trx t + ON p.id = t.trx_mysql_thread_id + WHERE + (? IS NULL OR p.time >= ?) + AND p.id != CONNECTION_ID() + AND Command NOT IN ('Binlog Dump', 'Binlog Dump GTID', 'Connect', 'Connect Out', 'Register Slave') + AND User NOT IN ('system user', 'event_scheduler') + AND (t.trx_id is NOT NULL OR command != 'Sleep') + ORDER BY + t.trx_started + LIMIT ?; ` const listActiveQueriesStatementCloudSQLMySQL = ` - SELECT - p.id AS processlist_id, - substring(IFNULL(p.info, t.trx_query), 1, 100) AS query, - t.trx_started AS trx_started, - (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_started)) AS trx_duration_seconds, - (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_wait_started)) AS trx_wait_duration_seconds, - p.time AS query_time, - t.trx_state AS trx_state, - p.state AS process_state, - IF(p.host IS NULL OR p.host = '', p.user, concat(p.user, '@', SUBSTRING_INDEX(p.host, ':', 1))) AS user, - t.trx_rows_locked AS trx_rows_locked, - t.trx_rows_modified AS trx_rows_modified, - p.db AS db - FROM - information_schema.processlist p - LEFT OUTER JOIN - information_schema.innodb_trx t - ON p.id = t.trx_mysql_thread_id - WHERE - (? IS NULL OR p.time >= ?) - AND p.id != CONNECTION_ID() - AND SUBSTRING_INDEX(IFNULL(p.host,''), ':', 1) NOT IN ('localhost', '127.0.0.1') - AND IFNULL(p.host,'') NOT LIKE '::1%' - AND Command NOT IN ('Binlog Dump', 'Binlog Dump GTID', 'Connect', 'Connect Out', 'Register Slave') - AND User NOT IN ('system user', 'event_scheduler') - AND (t.trx_id is NOT NULL OR command != 'sleep') - ORDER BY - t.trx_started - LIMIT ?; + SELECT + p.id AS processlist_id, + substring(IFNULL(p.info, t.trx_query), 1, 100) AS query, + t.trx_started AS trx_started, + (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_started)) AS trx_duration_seconds, + (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_wait_started)) AS trx_wait_duration_seconds, + p.time AS query_time, + t.trx_state AS trx_state, + p.state AS process_state, + IF(p.host IS NULL OR p.host = '', p.user, concat(p.user, '@', SUBSTRING_INDEX(p.host, ':', 1))) AS user, + t.trx_rows_locked AS trx_rows_locked, + t.trx_rows_modified AS trx_rows_modified, + p.db AS db + FROM + information_schema.processlist p + LEFT OUTER JOIN + information_schema.innodb_trx t + ON p.id = t.trx_mysql_thread_id + WHERE + (? IS NULL OR p.time >= ?) + AND p.id != CONNECTION_ID() + AND SUBSTRING_INDEX(IFNULL(p.host,''), ':', 1) NOT IN ('localhost', '127.0.0.1') + AND IFNULL(p.host,'') NOT LIKE '::1%' + AND Command NOT IN ('Binlog Dump', 'Binlog Dump GTID', 'Connect', 'Connect Out', 'Register Slave') + AND User NOT IN ('system user', 'event_scheduler') + AND (t.trx_id is NOT NULL OR command != 'sleep') + ORDER BY + t.trx_started + LIMIT ?; ` func init() { @@ -177,30 +178,34 @@ type Tool struct { statement string } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() duration, ok := paramsMap["min_duration_secs"].(int) if !ok { - return nil, fmt.Errorf("invalid 'min_duration_secs' parameter; expected an integer") + return nil, util.NewAgentError("invalid 'min_duration_secs' parameter; expected an integer", nil) } limit, ok := paramsMap["limit"].(int) if !ok { - return nil, fmt.Errorf("invalid 'limit' parameter; expected an integer") + return nil, util.NewAgentError("invalid 'limit' parameter; expected an integer", nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, t.statement)) - return source.RunSQL(ctx, t.statement, []any{duration, duration, limit}) + resp, err := source.RunSQL(ctx, t.statement, []any{duration, duration, limit}) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go b/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go index a6954284e5..4277a6379d 100644 --- a/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go +++ b/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -30,25 +31,25 @@ import ( const resourceType string = "mysql-list-table-fragmentation" const listTableFragmentationStatement = ` - SELECT - table_schema, - table_name, - data_length AS data_size, - index_length AS index_size, - data_free AS data_free, - ROUND((data_free / (data_length + index_length)) * 100, 2) AS fragmentation_percentage - FROM - information_schema.tables - WHERE - table_schema NOT IN ('sys', 'performance_schema', 'mysql', 'information_schema') - AND (COALESCE(?, '') = '' OR table_schema = ?) - AND (COALESCE(?, '') = '' OR table_name = ?) - AND data_free >= ? - ORDER BY - fragmentation_percentage DESC, - table_schema, - table_name - LIMIT ?; + SELECT + table_schema, + table_name, + data_length AS data_size, + index_length AS index_size, + data_free AS data_free, + ROUND((data_free / (data_length + index_length)) * 100, 2) AS fragmentation_percentage + FROM + information_schema.tables + WHERE + table_schema NOT IN ('sys', 'performance_schema', 'mysql', 'information_schema') + AND (COALESCE(?, '') = '' OR table_schema = ?) + AND (COALESCE(?, '') = '' OR table_name = ?) + AND data_free >= ? + ORDER BY + fragmentation_percentage DESC, + table_schema, + table_name + LIMIT ?; ` func init() { @@ -114,39 +115,43 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() table_schema, ok := paramsMap["table_schema"].(string) if !ok { - return nil, fmt.Errorf("invalid 'table_schema' parameter; expected a string") + return nil, util.NewAgentError("invalid 'table_schema' parameter; expected a string", nil) } table_name, ok := paramsMap["table_name"].(string) if !ok { - return nil, fmt.Errorf("invalid 'table_name' parameter; expected a string") + return nil, util.NewAgentError("invalid 'table_name' parameter; expected a string", nil) } data_free_threshold_bytes, ok := paramsMap["data_free_threshold_bytes"].(int) if !ok { - return nil, fmt.Errorf("invalid 'data_free_threshold_bytes' parameter; expected an integer") + return nil, util.NewAgentError("invalid 'data_free_threshold_bytes' parameter; expected an integer", nil) } limit, ok := paramsMap["limit"].(int) if !ok { - return nil, fmt.Errorf("invalid 'limit' parameter; expected an integer") + return nil, util.NewAgentError("invalid 'limit' parameter; expected an integer", nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, listTableFragmentationStatement)) sliceParams := []any{table_schema, table_schema, table_name, table_name, data_free_threshold_bytes, limit} - return source.RunSQL(ctx, listTableFragmentationStatement, sliceParams) + resp, err := source.RunSQL(ctx, listTableFragmentationStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqllisttables/mysqllisttables.go b/internal/tools/mysql/mysqllisttables/mysqllisttables.go index 9f8879917a..cfca0f87c6 100644 --- a/internal/tools/mysql/mysqllisttables/mysqllisttables.go +++ b/internal/tools/mysql/mysqllisttables/mysqllisttables.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -244,32 +246,32 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() tableNames, ok := paramsMap["table_names"].(string) if !ok { - return nil, fmt.Errorf("invalid '%s' parameter; expected a string", tableNames) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string", tableNames), nil) } outputFormat, _ := paramsMap["output_format"].(string) if outputFormat != "simple" && outputFormat != "detailed" { - return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) + return nil, util.NewAgentError(fmt.Sprintf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat), nil) } resp, err := source.RunSQL(ctx, listTablesStatement, []any{tableNames, outputFormat}) if err != nil { - return nil, err + return nil, util.ProcessGeneralError(err) } // if there's no results, return empty list instead of null resSlice, ok := resp.([]any) if !ok || len(resSlice) == 0 { return []any{}, nil } - return resp, err + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go b/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go index 5cdeeae61f..50954e6f83 100644 --- a/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go +++ b/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -30,26 +31,26 @@ import ( const resourceType string = "mysql-list-tables-missing-unique-indexes" const listTablesMissingUniqueIndexesStatement = ` - SELECT - tab.table_schema AS table_schema, - tab.table_name AS table_name - FROM - information_schema.tables tab - LEFT JOIN - information_schema.table_constraints tco - ON - tab.table_schema = tco.table_schema - AND tab.table_name = tco.table_name - AND tco.constraint_type IN ('PRIMARY KEY', 'UNIQUE') - WHERE - tco.constraint_type IS NULL - AND tab.table_schema NOT IN('mysql', 'information_schema', 'performance_schema', 'sys') - AND tab.table_type = 'BASE TABLE' - AND (COALESCE(?, '') = '' OR tab.table_schema = ?) - ORDER BY - tab.table_schema, - tab.table_name - LIMIT ?; + SELECT + tab.table_schema AS table_schema, + tab.table_name AS table_name + FROM + information_schema.tables tab + LEFT JOIN + information_schema.table_constraints tco + ON + tab.table_schema = tco.table_schema + AND tab.table_name = tco.table_name + AND tco.constraint_type IN ('PRIMARY KEY', 'UNIQUE') + WHERE + tco.constraint_type IS NULL + AND tab.table_schema NOT IN('mysql', 'information_schema', 'performance_schema', 'sys') + AND tab.table_type = 'BASE TABLE' + AND (COALESCE(?, '') = '' OR tab.table_schema = ?) + ORDER BY + tab.table_schema, + tab.table_name + LIMIT ?; ` func init() { @@ -113,30 +114,34 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() table_schema, ok := paramsMap["table_schema"].(string) if !ok { - return nil, fmt.Errorf("invalid 'table_schema' parameter; expected a string") + return nil, util.NewAgentError("invalid 'table_schema' parameter; expected a string", nil) } limit, ok := paramsMap["limit"].(int) if !ok { - return nil, fmt.Errorf("invalid 'limit' parameter; expected an integer") + return nil, util.NewAgentError("invalid 'limit' parameter; expected an integer", nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, listTablesMissingUniqueIndexesStatement)) - return source.RunSQL(ctx, listTablesMissingUniqueIndexesStatement, []any{table_schema, table_schema, limit}) + resp, err := source.RunSQL(ctx, listTablesMissingUniqueIndexesStatement, []any{table_schema, table_schema, limit}) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqlsql/mysqlsql.go b/internal/tools/mysql/mysqlsql/mysqlsql.go index 79c0adbaf5..e65e562128 100644 --- a/internal/tools/mysql/mysqlsql/mysqlsql.go +++ b/internal/tools/mysql/mysqlsql/mysqlsql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -93,25 +95,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + resp, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/neo4j/neo4jcypher/neo4jcypher.go b/internal/tools/neo4j/neo4jcypher/neo4jcypher.go index 3c9459ff63..fc4cb89f1b 100644 --- a/internal/tools/neo4j/neo4jcypher/neo4jcypher.go +++ b/internal/tools/neo4j/neo4jcypher/neo4jcypher.go @@ -17,12 +17,14 @@ package neo4jcypher import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -85,14 +87,18 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - return source.RunQuery(ctx, t.Statement, paramsMap, false, false) + resp, err := source.RunQuery(ctx, t.Statement, paramsMap, false, false) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go b/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go index ef32d1c6e7..2ea2fd9681 100644 --- a/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go +++ b/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go @@ -17,11 +17,13 @@ package neo4jexecutecypher import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -94,28 +96,32 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() cypherStr, ok := paramsMap["cypher"].(string) if !ok { - return nil, fmt.Errorf("unable to cast cypher parameter %s", paramsMap["cypher"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast cypher parameter %s", paramsMap["cypher"]), nil) } if cypherStr == "" { - return nil, fmt.Errorf("parameter 'cypher' must be a non-empty string") + return nil, util.NewAgentError("parameter 'cypher' must be a non-empty string", nil) } dryRun, ok := paramsMap["dry_run"].(bool) if !ok { - return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast dry_run parameter %s", paramsMap["dry_run"]), nil) } - return source.RunQuery(ctx, cypherStr, nil, t.ReadOnly, dryRun) + resp, err := source.RunQuery(ctx, cypherStr, nil, t.ReadOnly, dryRun) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/neo4j/neo4jschema/neo4jschema.go b/internal/tools/neo4j/neo4jschema/neo4jschema.go index 441a8eaa72..9f217a2502 100644 --- a/internal/tools/neo4j/neo4jschema/neo4jschema.go +++ b/internal/tools/neo4j/neo4jschema/neo4jschema.go @@ -17,6 +17,7 @@ package neo4jschema import ( "context" "fmt" + "net/http" "sync" "time" @@ -27,6 +28,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/cache" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/types" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/neo4j/neo4j-go-driver/v5/neo4j" ) @@ -113,10 +115,10 @@ type Tool struct { // Invoke executes the tool's main logic: fetching the Neo4j schema. // It first checks the cache for a valid schema before extracting it from the database. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } // Check if a valid schema is already in the cache. @@ -129,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // If not cached, extract the schema from the database. schema, err := t.extractSchema(ctx, source) if err != nil { - return nil, fmt.Errorf("failed to extract database schema: %w", err) + return nil, util.ProcessGeneralError(err) } // Cache the newly extracted schema for future use. @@ -372,14 +374,14 @@ func (t Tool) GetAPOCSchema(ctx context.Context, source compatibleSource) ([]typ name: "apoc-relationships", fn: func(session neo4j.SessionWithContext) error { query := ` - MATCH (startNode)-[rel]->(endNode) - WITH - labels(startNode)[0] AS startNode, - type(rel) AS relType, - apoc.meta.cypher.types(rel) AS relProperties, - labels(endNode)[0] AS endNode, - count(*) AS count - RETURN relType, startNode, endNode, relProperties, count` + MATCH (startNode)-[rel]->(endNode) + WITH + labels(startNode)[0] AS startNode, + type(rel) AS relType, + apoc.meta.cypher.types(rel) AS relProperties, + labels(endNode)[0] AS endNode, + count(*) AS count + RETURN relType, startNode, endNode, relProperties, count` result, err := session.Run(ctx, query, nil) if err != nil { return fmt.Errorf("failed to extract relationships: %w", err) @@ -520,10 +522,10 @@ func (t Tool) GetSchemaWithoutAPOC(ctx context.Context, source compatibleSource, name: "relationship-schema", fn: func(session neo4j.SessionWithContext) error { relQuery := ` - MATCH (start)-[r]->(end) - WITH type(r) AS relType, labels(start) AS startLabels, labels(end) AS endLabels, count(*) AS count - RETURN relType, CASE WHEN size(startLabels) > 0 THEN startLabels[0] ELSE null END AS startLabel, CASE WHEN size(endLabels) > 0 THEN endLabels[0] ELSE null END AS endLabel, sum(count) AS totalCount - ORDER BY totalCount DESC` + MATCH (start)-[r]->(end) + WITH type(r) AS relType, labels(start) AS startLabels, labels(end) AS endLabels, count(*) AS count + RETURN relType, CASE WHEN size(startLabels) > 0 THEN startLabels[0] ELSE null END AS startLabel, CASE WHEN size(endLabels) > 0 THEN endLabels[0] ELSE null END AS endLabel, sum(count) AS totalCount + ORDER BY totalCount DESC` relResult, err := session.Run(ctx, relQuery, nil) if err != nil { return fmt.Errorf("relationship count query failed: %w", err) diff --git a/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go b/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go index 1987f24d45..173199daea 100644 --- a/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go +++ b/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -89,18 +91,22 @@ type Tool struct { } // Invoke executes the SQL statement provided in the parameters. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } sliceParams := params.AsSlice() sqlStr, ok := sliceParams[0].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", sliceParams[0]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", sliceParams[0]), nil) } - return source.RunSQL(ctx, sqlStr, nil) + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go b/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go index 0b8a7421d3..ddcc83fbc5 100644 --- a/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go +++ b/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -94,24 +96,28 @@ type Tool struct { } // Invoke executes the SQL statement with the provided parameters. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + resp, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go b/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go index c91b3bcc06..1f7a047681 100644 --- a/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go +++ b/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go @@ -6,6 +6,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -77,25 +78,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() sqlParam, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql"]), nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "executing `%s` tool query: %s", resourceType, sqlParam) - return source.RunSQL(ctx, sqlParam, nil) + resp, err := source.RunSQL(ctx, sqlParam, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/oracle/oraclesql/oraclesql.go b/internal/tools/oracle/oraclesql/oraclesql.go index 347b18d41b..84041ce6b1 100644 --- a/internal/tools/oracle/oraclesql/oraclesql.go +++ b/internal/tools/oracle/oraclesql/oraclesql.go @@ -6,11 +6,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -81,21 +83,21 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() @@ -103,7 +105,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para fmt.Printf("[%d]=%T ", i, p) } fmt.Printf("\n") - return source.RunSQL(ctx, newStatement, sliceParams) + resp, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go b/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go index 4d48cbc6cb..e621a142ff 100644 --- a/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go +++ b/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go @@ -17,11 +17,13 @@ package postgresdatabaseoverview import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,13 +31,13 @@ import ( const resourceType string = "postgres-database-overview" const databaseOverviewStatement = ` - SELECT + SELECT current_setting('server_version') AS pg_version, pg_is_in_recovery() AS is_replica, (now() - pg_postmaster_start_time())::TEXT AS uptime, current_setting('max_connections')::int AS max_connections, - (SELECT count(*) FROM pg_stat_activity) AS current_connections, - (SELECT count(*) FROM pg_stat_activity WHERE state = 'active') AS active_connections, + (SELECT count(*) FROM pg_stat_activity) AS current_connections, + (SELECT count(*) FROM pg_stat_activity WHERE state = 'active') AS active_connections, round( (100.0 * (SELECT count(*) FROM pg_stat_activity) / current_setting('max_connections')::int), 2 @@ -57,7 +59,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - PostgresPool() *pgxpool.Pool // keep this so that sources are postgres compatible + PostgresPool() *pgxpool.Pool RunSQL(context.Context, string, []any) (any, error) } @@ -69,7 +71,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -83,7 +84,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -96,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -110,20 +109,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, databaseOverviewStatement, sliceParams) + resp, err := source.RunSQL(ctx, databaseOverviewStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go b/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go index 57e0c8fce4..7b81f9bfce 100644 --- a/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go +++ b/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go @@ -17,6 +17,7 @@ package postgresexecutesql import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -56,7 +57,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -69,7 +69,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) - // finish tool setup t := Tool{ Config: cfg, Parameters: params, @@ -79,7 +78,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -89,25 +87,28 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql"]), nil) } - // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, sql, nil) + resp, err := source.RunSQL(ctx, sql, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go b/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go index 81cc92673e..b4358f439f 100644 --- a/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go +++ b/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go @@ -17,11 +17,13 @@ package postgresgetcolumncardinality import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-get-column-cardinality" const getColumnCardinality = ` - SELECT + SELECT s.attname AS column_name, ROUND( CASE @@ -74,7 +76,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -95,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -108,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -122,20 +121,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, getColumnCardinality, sliceParams) + resp, err := source.RunSQL(ctx, getColumnCardinality, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go b/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go index a7b1f7587d..ab4b36c3a3 100644 --- a/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go +++ b/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go @@ -17,11 +17,13 @@ package postgreslistactivequeries import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,26 +31,26 @@ import ( const resourceType string = "postgres-list-active-queries" const listActiveQueriesStatement = ` - SELECT - pid, - usename AS user, - datname, - application_name, - client_addr, - state, - wait_event_type, - wait_event, - backend_start, - xact_start, - query_start, - now() - query_start AS query_duration, - query - FROM pg_stat_activity - WHERE state = 'active' - AND ($1::INTERVAL IS NULL OR now() - query_start >= $1::INTERVAL) - AND ($2::text IS NULL OR application_name NOT IN (SELECT trim(app) FROM unnest(string_to_array($2, ',')) AS app)) - ORDER BY query_duration DESC - LIMIT COALESCE($3::int, 50); + SELECT + pid, + usename AS user, + datname, + application_name, + client_addr, + state, + wait_event_type, + wait_event, + backend_start, + xact_start, + query_start, + now() - query_start AS query_duration, + query + FROM pg_stat_activity + WHERE state = 'active' + AND ($1::INTERVAL IS NULL OR now() - query_start >= $1::INTERVAL) + AND ($2::text IS NULL OR application_name NOT IN (SELECT trim(app) FROM unnest(string_to_array($2, ',')) AS app)) + ORDER BY query_duration DESC + LIMIT COALESCE($3::int, 50); ` func init() { @@ -78,7 +80,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -94,8 +95,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) paramManifest := allParameters.Manifest() mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup - t := Tool{ + return Tool{ Config: cfg, allParams: allParameters, manifest: tools.Manifest{ @@ -104,11 +104,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) AuthRequired: cfg.AuthRequired, }, mcpManifest: mcpManifest, - } - return t, nil + }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -118,21 +116,25 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listActiveQueriesStatement, sliceParams) + resp, err := source.RunSQL(ctx, listActiveQueriesStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go b/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go index 489df27583..6ecf06509d 100644 --- a/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go +++ b/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go @@ -17,11 +17,13 @@ package postgreslistavailableextensions import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,13 +31,13 @@ import ( const resourceType string = "postgres-list-available-extensions" const listAvailableExtensionsQuery = ` - SELECT - name, - default_version, - comment as description - FROM - pg_available_extensions - ORDER BY name; + SELECT + name, + default_version, + comment as description + FROM + pg_available_extensions + ORDER BY name; ` func init() { @@ -65,7 +67,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -76,7 +77,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) - // finish tool setup t := Tool{ Config: cfg, manifest: tools.Manifest{ @@ -90,7 +90,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -100,12 +99,16 @@ type Tool struct { Parameters parameters.Parameters } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } - return source.RunSQL(ctx, listAvailableExtensionsQuery, nil) + resp, err := source.RunSQL(ctx, listAvailableExtensionsQuery, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go b/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go index c78dd297d7..f01fc002a6 100644 --- a/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go +++ b/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go @@ -17,82 +17,83 @@ package postgreslistdatabasestats import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) const resourceType string = "postgres-list-database-stats" -// SQL query to list database statistics const listDatabaseStats = ` - WITH database_stats AS ( - SELECT - s.datname AS database_name, - -- Database Metadata - d.datallowconn AS is_connectable, - pg_get_userbyid(d.datdba) AS database_owner, - ts.spcname AS default_tablespace, + WITH database_stats AS ( + SELECT + s.datname AS database_name, + -- Database Metadata + d.datallowconn AS is_connectable, + pg_get_userbyid(d.datdba) AS database_owner, + ts.spcname AS default_tablespace, - -- Cache Performance - CASE - WHEN (s.blks_hit + s.blks_read) = 0 THEN 0 - ELSE round((s.blks_hit * 100.0) / (s.blks_hit + s.blks_read), 2) - END AS cache_hit_ratio_percent, - s.blks_read AS blocks_read_from_disk, - s.blks_hit AS blocks_hit_in_cache, + -- Cache Performance + CASE + WHEN (s.blks_hit + s.blks_read) = 0 THEN 0 + ELSE round((s.blks_hit * 100.0) / (s.blks_hit + s.blks_read), 2) + END AS cache_hit_ratio_percent, + s.blks_read AS blocks_read_from_disk, + s.blks_hit AS blocks_hit_in_cache, - -- Transaction Throughput - s.xact_commit, - s.xact_rollback, - round(s.xact_rollback * 100.0 / (s.xact_commit + s.xact_rollback + 1), 2) AS rollback_ratio_percent, + -- Transaction Throughput + s.xact_commit, + s.xact_rollback, + round(s.xact_rollback * 100.0 / (s.xact_commit + s.xact_rollback + 1), 2) AS rollback_ratio_percent, - -- Tuple Activity - s.tup_returned AS rows_returned_by_queries, - s.tup_fetched AS rows_fetched_by_scans, - s.tup_inserted, - s.tup_updated, - s.tup_deleted, + -- Tuple Activity + s.tup_returned AS rows_returned_by_queries, + s.tup_fetched AS rows_fetched_by_scans, + s.tup_inserted, + s.tup_updated, + s.tup_deleted, - -- Temporary File Usage - s.temp_files, - s.temp_bytes AS temp_size_bytes, + -- Temporary File Usage + s.temp_files, + s.temp_bytes AS temp_size_bytes, - -- Conflicts & Deadlocks - s.conflicts, - s.deadlocks, + -- Conflicts & Deadlocks + s.conflicts, + s.deadlocks, - -- General Info - s.numbackends AS active_connections, - s.stats_reset AS statistics_last_reset, - pg_database_size(s.datid) AS database_size_bytes - FROM - pg_stat_database s - JOIN - pg_database d ON d.oid = s.datid - JOIN - pg_tablespace ts ON ts.oid = d.dattablespace - WHERE - -- Exclude cloudsql internal databases - s.datname NOT IN ('cloudsqladmin') - -- Exclude template databases if not requested - AND ( $2::boolean IS TRUE OR d.datistemplate IS FALSE ) - ) - SELECT * - FROM database_stats - WHERE - ($1::text IS NULL OR database_name LIKE '%' || $1::text || '%') - AND ($3::text IS NULL OR database_owner LIKE '%' || $3::text || '%') - AND ($4::text IS NULL OR default_tablespace LIKE '%' || $4::text || '%') - ORDER BY - CASE WHEN $5::text = 'size' THEN database_size_bytes END DESC, - CASE WHEN $5::text = 'commit' THEN xact_commit END DESC, - database_name - LIMIT COALESCE($6::int, 10); + -- General Info + s.numbackends AS active_connections, + s.stats_reset AS statistics_last_reset, + pg_database_size(s.datid) AS database_size_bytes + FROM + pg_stat_database s + JOIN + pg_database d ON d.oid = s.datid + JOIN + pg_tablespace ts ON ts.oid = d.dattablespace + WHERE + -- Exclude cloudsql internal databases + s.datname NOT IN ('cloudsqladmin') + -- Exclude template databases if not requested + AND ( $2::boolean IS TRUE OR d.datistemplate IS FALSE ) + ) + SELECT * + FROM database_stats + WHERE + ($1::text IS NULL OR database_name LIKE '%' || $1::text || '%') + AND ($3::text IS NULL OR database_owner LIKE '%' || $3::text || '%') + AND ($4::text IS NULL OR default_tablespace LIKE '%' || $4::text || '%') + ORDER BY + CASE WHEN $5::text = 'size' THEN database_size_bytes END DESC, + CASE WHEN $5::text = 'commit' THEN xact_commit END DESC, + database_name + LIMIT COALESCE($6::int, 10); ` func init() { @@ -122,7 +123,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -164,7 +164,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -177,7 +176,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -187,21 +185,25 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listDatabaseStats, sliceParams) + resp, err := source.RunSQL(ctx, listDatabaseStats, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go b/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go index debd2d8036..10f8b92327 100644 --- a/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go +++ b/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go @@ -17,11 +17,13 @@ package postgreslistindexes import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,49 +31,49 @@ import ( const resourceType string = "postgres-list-indexes" const listIndexesStatement = ` - WITH IndexDetails AS ( - SELECT - s.schemaname AS schema_name, - t.relname AS table_name, - i.relname AS index_name, - am.amname AS index_type, - ix.indisunique AS is_unique, - ix.indisprimary AS is_primary, - pg_get_indexdef(i.oid) AS index_definition, - pg_relation_size(i.oid) AS index_size_bytes, - s.idx_scan AS index_scans, - s.idx_tup_read AS tuples_read, - s.idx_tup_fetch AS tuples_fetched, - CASE - WHEN s.idx_scan > 0 THEN true - ELSE false - END AS is_used - FROM pg_catalog.pg_class t - JOIN pg_catalog.pg_index ix - ON t.oid = ix.indrelid - JOIN pg_catalog.pg_class i - ON i.oid = ix.indexrelid - JOIN pg_catalog.pg_am am - ON i.relam = am.oid - JOIN pg_catalog.pg_stat_all_indexes s - ON i.oid = s.indexrelid - WHERE - t.relkind = 'r' - AND s.schemaname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') - AND s.schemaname NOT LIKE 'pg_temp_%' - ) - SELECT * - FROM IndexDetails - WHERE - ($1::text IS NULL OR schema_name LIKE '%' || $1 || '%') - AND ($2::text IS NULL OR table_name LIKE '%' || $2 || '%') - AND ($3::text IS NULL OR index_name LIKE '%' || $3 || '%') - AND ($4::boolean IS NOT TRUE OR is_used IS FALSE) - ORDER BY - schema_name, - table_name, - index_name - LIMIT COALESCE($5::int, 50); + WITH IndexDetails AS ( + SELECT + s.schemaname AS schema_name, + t.relname AS table_name, + i.relname AS index_name, + am.amname AS index_type, + ix.indisunique AS is_unique, + ix.indisprimary AS is_primary, + pg_get_indexdef(i.oid) AS index_definition, + pg_relation_size(i.oid) AS index_size_bytes, + s.idx_scan AS index_scans, + s.idx_tup_read AS tuples_read, + s.idx_tup_fetch AS tuples_fetched, + CASE + WHEN s.idx_scan > 0 THEN true + ELSE false + END AS is_used + FROM pg_catalog.pg_class t + JOIN pg_catalog.pg_index ix + ON t.oid = ix.indrelid + JOIN pg_catalog.pg_class i + ON i.oid = ix.indexrelid + JOIN pg_catalog.pg_am am + ON i.relam = am.oid + JOIN pg_catalog.pg_stat_all_indexes s + ON i.oid = s.indexrelid + WHERE + t.relkind = 'r' + AND s.schemaname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') + AND s.schemaname NOT LIKE 'pg_temp_%' + ) + SELECT * + FROM IndexDetails + WHERE + ($1::text IS NULL OR schema_name LIKE '%' || $1 || '%') + AND ($2::text IS NULL OR table_name LIKE '%' || $2 || '%') + AND ($3::text IS NULL OR index_name LIKE '%' || $3 || '%') + AND ($4::boolean IS NOT TRUE OR is_used IS FALSE) + ORDER BY + schema_name, + table_name, + index_name + LIMIT COALESCE($5::int, 50); ` func init() { @@ -101,7 +103,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -122,7 +123,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -135,7 +135,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -149,21 +148,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listIndexesStatement, sliceParams) + resp, err := source.RunSQL(ctx, listIndexesStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go b/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go index 8273ae9247..cdac40ab0e 100644 --- a/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go +++ b/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go @@ -17,11 +17,13 @@ package postgreslistinstalledextensions import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,24 +31,24 @@ import ( const resourceType string = "postgres-list-installed-extensions" const listAvailableExtensionsQuery = ` - SELECT - e.extname AS name, - e.extversion AS version, - n.nspname AS schema, - pg_get_userbyid(e.extowner) AS owner, - c.description AS description - FROM - pg_catalog.pg_extension e - LEFT JOIN - pg_catalog.pg_namespace n - ON - n.oid = e.extnamespace - LEFT JOIN - pg_catalog.pg_description c - ON - c.objoid = e.oid - AND c.classoid = 'pg_catalog.pg_extension'::pg_catalog.regclass - ORDER BY 1; + SELECT + e.extname AS name, + e.extversion AS version, + n.nspname AS schema, + pg_get_userbyid(e.extowner) AS owner, + c.description AS description + FROM + pg_catalog.pg_extension e + LEFT JOIN + pg_catalog.pg_namespace n + ON + n.oid = e.extnamespace + LEFT JOIN + pg_catalog.pg_description c + ON + c.objoid = e.oid + AND c.classoid = 'pg_catalog.pg_extension'::pg_catalog.regclass + ORDER BY 1; ` func init() { @@ -76,7 +78,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -87,7 +88,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) - // finish tool setup t := Tool{ Config: cfg, manifest: tools.Manifest{ @@ -100,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -109,12 +108,16 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } - return source.RunSQL(ctx, listAvailableExtensionsQuery, nil) + resp, err := source.RunSQL(ctx, listAvailableExtensionsQuery, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { @@ -145,7 +148,6 @@ func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, return "Authorization", nil } -// This tool does not have parameters, so return an empty set. func (t Tool) GetParameters() parameters.Parameters { return parameters.Parameters{} } diff --git a/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go b/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go index b29cb1e57e..bac4b6a01b 100644 --- a/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go +++ b/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go @@ -17,11 +17,13 @@ package postgreslistlocks import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-list-locks" const listLocks = ` - SELECT + SELECT locked.pid, locked.usename, locked.query, @@ -76,7 +78,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -93,7 +94,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -106,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -120,21 +119,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listLocks, sliceParams) + resp, err := source.RunSQL(ctx, listLocks, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go b/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go index 6d10837830..85d9dd0e35 100644 --- a/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go +++ b/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go @@ -17,11 +17,13 @@ package postgreslistpgsettings import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-list-pg-settings" const listPgSettingsStatement = ` - SELECT + SELECT name, setting AS current_value, unit, @@ -41,10 +43,10 @@ const listPgSettingsStatement = ` ELSE 'No' END AS requires_restart - FROM pg_settings - WHERE ($1::text IS NULL OR name LIKE '%' || $1::text || '%') - ORDER BY name - LIMIT COALESCE($2::int, 50); + FROM pg_settings + WHERE ($1::text IS NULL OR name LIKE '%' || $1::text || '%') + ORDER BY name + LIMIT COALESCE($2::int, 50); ` func init() { @@ -74,7 +76,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -92,7 +93,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -105,7 +105,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -115,19 +114,23 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listPgSettingsStatement, sliceParams) + resp, err := source.RunSQL(ctx, listPgSettingsStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go b/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go index db6af2c62f..a5ee63db16 100644 --- a/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go +++ b/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go @@ -17,11 +17,13 @@ package postgreslistpublicationtables import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,33 +31,33 @@ import ( const resourceType string = "postgres-list-publication-tables" const listPublicationTablesStatement = ` - WITH - publication_details AS ( - SELECT - pt.pubname AS publication_name, - pt.schemaname AS schema_name, - pt.tablename AS table_name, - -- Definition details - p.puballtables AS publishes_all_tables, - p.pubinsert AS publishes_inserts, - p.pubupdate AS publishes_updates, - p.pubdelete AS publishes_deletes, - p.pubtruncate AS publishes_truncates, - -- Owner information - pg_catalog.pg_get_userbyid(p.pubowner) AS publication_owner - FROM pg_catalog.pg_publication_tables pt - JOIN pg_catalog.pg_publication p - ON pt.pubname = p.pubname - ) - SELECT * - FROM publication_details - WHERE - (NULLIF(TRIM($1::text), '') IS NULL OR table_name = ANY(regexp_split_to_array(TRIM($1::text), '\s*,\s*'))) - AND (NULLIF(TRIM($2::text), '') IS NULL OR publication_name = ANY(regexp_split_to_array(TRIM($2::text), '\s*,\s*'))) - AND (NULLIF(TRIM($3::text), '') IS NULL OR schema_name = ANY(regexp_split_to_array(TRIM($3::text), '\s*,\s*'))) - ORDER BY - publication_name, schema_name, table_name - LIMIT COALESCE($4::int, 50); + WITH + publication_details AS ( + SELECT + pt.pubname AS publication_name, + pt.schemaname AS schema_name, + pt.tablename AS table_name, + -- Definition details + p.puballtables AS publishes_all_tables, + p.pubinsert AS publishes_inserts, + p.pubupdate AS publishes_updates, + p.pubdelete AS publishes_deletes, + p.pubtruncate AS publishes_truncates, + -- Owner information + pg_catalog.pg_get_userbyid(p.pubowner) AS publication_owner + FROM pg_catalog.pg_publication_tables pt + JOIN pg_catalog.pg_publication p + ON pt.pubname = p.pubname + ) + SELECT * + FROM publication_details + WHERE + (NULLIF(TRIM($1::text), '') IS NULL OR table_name = ANY(regexp_split_to_array(TRIM($1::text), '\s*,\s*'))) + AND (NULLIF(TRIM($2::text), '') IS NULL OR publication_name = ANY(regexp_split_to_array(TRIM($2::text), '\s*,\s*'))) + AND (NULLIF(TRIM($3::text), '') IS NULL OR schema_name = ANY(regexp_split_to_array(TRIM($3::text), '\s*,\s*'))) + ORDER BY + publication_name, schema_name, table_name + LIMIT COALESCE($4::int, 50); ` func init() { @@ -85,7 +87,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -105,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -118,7 +118,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -128,20 +127,24 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listPublicationTablesStatement, sliceParams) + resp, err := source.RunSQL(ctx, listPublicationTablesStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go b/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go index 20303abfc3..f54c3dc554 100644 --- a/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go +++ b/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go @@ -17,11 +17,13 @@ package postgreslistquerystats import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-list-query-stats" const listQueryStats = ` - SELECT + SELECT d.datname, s.query, s.calls, @@ -75,7 +77,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -95,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -108,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -122,19 +121,23 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listQueryStats, sliceParams) + resp, err := source.RunSQL(ctx, listQueryStats, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistroles/postgreslistroles.go b/internal/tools/postgres/postgreslistroles/postgreslistroles.go index c14b652c58..20cf87c20f 100644 --- a/internal/tools/postgres/postgreslistroles/postgreslistroles.go +++ b/internal/tools/postgres/postgreslistroles/postgreslistroles.go @@ -17,11 +17,13 @@ package postgreslistroles import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,45 +31,45 @@ import ( const resourceType string = "postgres-list-roles" const listRolesStatement = ` - WITH RoleDetails AS ( - SELECT - r.rolname AS role_name, - r.oid AS oid, - r.rolconnlimit AS connection_limit, - r.rolsuper AS is_superuser, - r.rolinherit AS inherits_privileges, - r.rolcreaterole AS can_create_roles, - r.rolcreatedb AS can_create_db, - r.rolcanlogin AS can_login, - r.rolreplication AS is_replication_role, - r.rolbypassrls AS bypass_rls, - r.rolvaliduntil AS valid_until, - -- List of roles that belong to this role (Direct Members) - ARRAY( - SELECT m_r.rolname - FROM pg_auth_members pam - JOIN pg_roles m_r ON pam.member = m_r.oid - WHERE pam.roleid = r.oid - ) AS direct_members, - -- List of roles that this role belongs to (Member Of) - ARRAY( - SELECT g_r.rolname - FROM pg_auth_members pam - JOIN pg_roles g_r ON pam.roleid = g_r.oid - WHERE pam.member = r.oid - ) AS member_of - FROM pg_roles r - -- Exclude system and internal roles - WHERE r.rolname NOT LIKE 'cloudsql%' - AND r.rolname NOT LIKE 'alloydb_%' - AND r.rolname NOT LIKE 'pg_%' - ) - SELECT * - FROM RoleDetails - WHERE - ($1::text IS NULL OR role_name LIKE '%' || $1 || '%') - ORDER BY role_name - LIMIT COALESCE($2::int, 50); + WITH RoleDetails AS ( + SELECT + r.rolname AS role_name, + r.oid AS oid, + r.rolconnlimit AS connection_limit, + r.rolsuper AS is_superuser, + r.rolinherit AS inherits_privileges, + r.rolcreaterole AS can_create_roles, + r.rolcreatedb AS can_create_db, + r.rolcanlogin AS can_login, + r.rolreplication AS is_replication_role, + r.rolbypassrls AS bypass_rls, + r.rolvaliduntil AS valid_until, + -- List of roles that belong to this role (Direct Members) + ARRAY( + SELECT m_r.rolname + FROM pg_auth_members pam + JOIN pg_roles m_r ON pam.member = m_r.oid + WHERE pam.roleid = r.oid + ) AS direct_members, + -- List of roles that this role belongs to (Member Of) + ARRAY( + SELECT g_r.rolname + FROM pg_auth_members pam + JOIN pg_roles g_r ON pam.roleid = g_r.oid + WHERE pam.member = r.oid + ) AS member_of + FROM pg_roles r + -- Exclude system and internal roles + WHERE r.rolname NOT LIKE 'cloudsql%' + AND r.rolname NOT LIKE 'alloydb_%' + AND r.rolname NOT LIKE 'pg_%' + ) + SELECT * + FROM RoleDetails + WHERE + ($1::text IS NULL OR role_name LIKE '%' || $1 || '%') + ORDER BY role_name + LIMIT COALESCE($2::int, 50); ` func init() { @@ -97,7 +99,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -116,7 +117,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -129,7 +129,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -143,20 +142,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listRolesStatement, sliceParams) + resp, err := source.RunSQL(ctx, listRolesStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go b/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go index dbf2a8b367..b1ff208f08 100644 --- a/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go +++ b/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go @@ -17,11 +17,13 @@ package postgreslistschemas import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-list-schemas" const listSchemasStatement = ` - WITH + WITH schema_grants AS ( SELECT schema_oid, jsonb_object_agg(grantee, privileges) AS grants FROM @@ -52,27 +54,27 @@ const listSchemasStatement = ` SELECT n.nspname AS schema_name, pg_catalog.pg_get_userbyid(n.nspowner) AS owner, - COALESCE(sg.grants, '{}'::jsonb) AS grants, - ( - SELECT COUNT(*) - FROM pg_catalog.pg_class c - WHERE c.relnamespace = n.oid AND c.relkind = 'r' - ) AS tables, - ( - SELECT COUNT(*) - FROM pg_catalog.pg_class c - WHERE c.relnamespace = n.oid AND c.relkind = 'v' - ) AS views, - (SELECT COUNT(*) FROM pg_catalog.pg_proc p WHERE p.pronamespace = n.oid) - AS functions + COALESCE(sg.grants, '{}'::jsonb) AS grants, + ( + SELECT COUNT(*) + FROM pg_catalog.pg_class c + WHERE c.relnamespace = n.oid AND c.relkind = 'r' + ) AS tables, + ( + SELECT COUNT(*) + FROM pg_catalog.pg_class c + WHERE c.relnamespace = n.oid AND c.relkind = 'v' + ) AS views, + (SELECT COUNT(*) FROM pg_catalog.pg_proc p WHERE p.pronamespace = n.oid) + AS functions FROM pg_catalog.pg_namespace n LEFT JOIN schema_grants sg ON n.oid = sg.schema_oid ) - SELECT * - FROM all_schemas - -- Exclude system schemas and temporary schemas created per session. - WHERE + SELECT * + FROM all_schemas + -- Exclude system and temporary schemas created per session. + WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') AND schema_name NOT LIKE 'pg_temp_%' AND schema_name NOT LIKE 'pg_toast_temp_%' @@ -109,7 +111,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -128,7 +129,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -141,7 +141,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -151,20 +150,24 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listSchemasStatement, sliceParams) + resp, err := source.RunSQL(ctx, listSchemasStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go b/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go index bee44edbca..aca352317c 100644 --- a/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go +++ b/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go @@ -17,11 +17,13 @@ package postgreslistsequences import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,22 +31,22 @@ import ( const resourceType string = "postgres-list-sequences" const listSequencesStatement = ` - SELECT - sequencename as sequence_name, - schemaname as schema_name, - sequenceowner as sequence_owner, - data_type, - start_value, - min_value, - max_value, - increment_by, - last_value - FROM pg_sequences - WHERE - ($1::text IS NULL OR schemaname LIKE '%' || $1 || '%') - AND ($2::text IS NULL OR sequencename LIKE '%' || $2 || '%') - ORDER BY schema_name, sequence_name - LIMIT COALESCE($3::int, 50); + SELECT + sequencename as sequence_name, + schemaname as schema_name, + sequenceowner as sequence_owner, + data_type, + start_value, + min_value, + max_value, + increment_by, + last_value + FROM pg_sequences + WHERE + ($1::text IS NULL OR schemaname LIKE '%' || $1 || '%') + AND ($2::text IS NULL OR sequencename LIKE '%' || $2 || '%') + ORDER BY schema_name, sequence_name + LIMIT COALESCE($3::int, 50); ` @@ -75,7 +77,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -94,7 +95,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -107,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -121,21 +120,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listSequencesStatement, sliceParams) + resp, err := source.RunSQL(ctx, listSequencesStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresliststoredprocedure/postgresliststoredprocedure.go b/internal/tools/postgres/postgresliststoredprocedure/postgresliststoredprocedure.go index f8d9891cac..96c727a020 100644 --- a/internal/tools/postgres/postgresliststoredprocedure/postgresliststoredprocedure.go +++ b/internal/tools/postgres/postgresliststoredprocedure/postgresliststoredprocedure.go @@ -17,6 +17,7 @@ package postgresliststoredprocedure import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -25,6 +26,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -32,7 +34,7 @@ import ( const resourceType string = "postgres-list-stored-procedure" const listStoredProcedure = ` - SELECT + SELECT n.nspname AS schema_name, p.proname AS name, r.rolname AS owner, @@ -85,7 +87,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -118,7 +119,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -132,7 +132,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -147,18 +146,18 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() results, err := t.pool.Query(ctx, listStoredProcedure, sliceParams...) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, util.ProcessGeneralError(err) } defer results.Close() @@ -168,7 +167,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para for results.Next() { values, err := results.Values() if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) + return nil, util.NewClientServerError("unable to parse row", http.StatusInternalServerError, err) } rowMap := make(map[string]any) for i, field := range fields { diff --git a/internal/tools/postgres/postgreslisttables/postgreslisttables.go b/internal/tools/postgres/postgreslisttables/postgreslisttables.go index da3ea82af0..70a4b594e9 100644 --- a/internal/tools/postgres/postgreslisttables/postgreslisttables.go +++ b/internal/tools/postgres/postgreslisttables/postgreslisttables.go @@ -17,11 +17,13 @@ package postgreslisttables import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,81 +31,81 @@ import ( const resourceType string = "postgres-list-tables" const listTablesStatement = ` - WITH desired_relkinds AS ( - SELECT ARRAY['r', 'p']::char[] AS kinds -- Always consider both 'TABLE' and 'PARTITIONED TABLE' - ), - table_info AS ( - SELECT - t.oid AS table_oid, - ns.nspname AS schema_name, - t.relname AS table_name, - pg_get_userbyid(t.relowner) AS table_owner, - obj_description(t.oid, 'pg_class') AS table_comment, - t.relkind AS object_kind - FROM - pg_class t - JOIN - pg_namespace ns ON ns.oid = t.relnamespace - CROSS JOIN desired_relkinds dk - WHERE - t.relkind = ANY(dk.kinds) -- Filter by selected table relkinds ('r', 'p') - AND (NULLIF(TRIM($1), '') IS NULL OR t.relname = ANY(string_to_array($1,','))) -- $1 is object_names - AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast','google_ml') - AND ns.nspname NOT LIKE 'pg_temp_%' AND ns.nspname NOT LIKE 'pg_toast_temp_%' - ), - columns_info AS ( - SELECT - att.attrelid AS table_oid, att.attname AS column_name, format_type(att.atttypid, att.atttypmod) AS data_type, - att.attnum AS column_ordinal_position, att.attnotnull AS is_not_nullable, - pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(att.attrelid, att.attnum) AS column_comment - FROM pg_attribute att LEFT JOIN pg_attrdef ad ON att.attrelid = ad.adrelid AND att.attnum = ad.adnum - JOIN table_info ti ON att.attrelid = ti.table_oid WHERE att.attnum > 0 AND NOT att.attisdropped - ), - constraints_info AS ( - SELECT - con.conrelid AS table_oid, con.conname AS constraint_name, pg_get_constraintdef(con.oid) AS constraint_definition, - CASE con.contype WHEN 'p' THEN 'PRIMARY KEY' WHEN 'f' THEN 'FOREIGN KEY' WHEN 'u' THEN 'UNIQUE' WHEN 'c' THEN 'CHECK' ELSE con.contype::text END AS constraint_type, - (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.conkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = u.attnum) AS constraint_columns, - NULLIF(con.confrelid, 0)::regclass AS foreign_key_referenced_table, - (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.confkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.confrelid AND att.attnum = u.attnum WHERE con.contype = 'f') AS foreign_key_referenced_columns - FROM pg_constraint con JOIN table_info ti ON con.conrelid = ti.table_oid - ), - indexes_info AS ( - SELECT - idx.indrelid AS table_oid, ic.relname AS index_name, pg_get_indexdef(idx.indexrelid) AS index_definition, - idx.indisunique AS is_unique, idx.indisprimary AS is_primary, am.amname AS index_method, - (SELECT array_agg(att.attname ORDER BY u.ord) FROM unnest(idx.indkey::int[]) WITH ORDINALITY AS u(colidx, ord) LEFT JOIN pg_attribute att ON att.attrelid = idx.indrelid AND att.attnum = u.colidx WHERE u.colidx <> 0) AS index_columns - FROM pg_index idx JOIN pg_class ic ON ic.oid = idx.indexrelid JOIN pg_am am ON am.oid = ic.relam JOIN table_info ti ON idx.indrelid = ti.table_oid - ), - triggers_info AS ( - SELECT tg.tgrelid AS table_oid, tg.tgname AS trigger_name, pg_get_triggerdef(tg.oid) AS trigger_definition, tg.tgenabled AS trigger_enabled_state - FROM pg_trigger tg JOIN table_info ti ON tg.tgrelid = ti.table_oid WHERE NOT tg.tgisinternal - ) - SELECT - ti.schema_name, - ti.table_name AS object_name, - CASE - WHEN $2 = 'simple' THEN - -- IF format is 'simple', return basic JSON - json_build_object('name', ti.table_name) - ELSE - json_build_object( - 'schema_name', ti.schema_name, - 'object_name', ti.table_name, - 'object_type', CASE ti.object_kind - WHEN 'r' THEN 'TABLE' - WHEN 'p' THEN 'PARTITIONED TABLE' - ELSE ti.object_kind::text -- Should not happen due to WHERE clause - END, - 'owner', ti.table_owner, - 'comment', ti.table_comment, - 'columns', COALESCE((SELECT json_agg(json_build_object('column_name',ci.column_name,'data_type',ci.data_type,'ordinal_position',ci.column_ordinal_position,'is_not_nullable',ci.is_not_nullable,'column_default',ci.column_default,'column_comment',ci.column_comment) ORDER BY ci.column_ordinal_position) FROM columns_info ci WHERE ci.table_oid = ti.table_oid), '[]'::json), - 'constraints', COALESCE((SELECT json_agg(json_build_object('constraint_name',cons.constraint_name,'constraint_type',cons.constraint_type,'constraint_definition',cons.constraint_definition,'constraint_columns',cons.constraint_columns,'foreign_key_referenced_table',cons.foreign_key_referenced_table,'foreign_key_referenced_columns',cons.foreign_key_referenced_columns)) FROM constraints_info cons WHERE cons.table_oid = ti.table_oid), '[]'::json), - 'indexes', COALESCE((SELECT json_agg(json_build_object('index_name',ii.index_name,'index_definition',ii.index_definition,'is_unique',ii.is_unique,'is_primary',ii.is_primary,'index_method',ii.index_method,'index_columns',ii.index_columns)) FROM indexes_info ii WHERE ii.table_oid = ti.table_oid), '[]'::json), - 'triggers', COALESCE((SELECT json_agg(json_build_object('trigger_name',tri.trigger_name,'trigger_definition',tri.trigger_definition,'trigger_enabled_state',tri.trigger_enabled_state)) FROM triggers_info tri WHERE tri.table_oid = ti.table_oid), '[]'::json) - ) - END AS object_details - FROM table_info ti ORDER BY ti.schema_name, ti.table_name; + WITH desired_relkinds AS ( + SELECT ARRAY['r', 'p']::char[] AS kinds -- Always consider both 'TABLE' and 'PARTITIONED TABLE' + ), + table_info AS ( + SELECT + t.oid AS table_oid, + ns.nspname AS schema_name, + t.relname AS table_name, + pg_get_userbyid(t.relowner) AS table_owner, + obj_description(t.oid, 'pg_class') AS table_comment, + t.relkind AS object_kind + FROM + pg_class t + JOIN + pg_namespace ns ON ns.oid = t.relnamespace + CROSS JOIN desired_relkinds dk + WHERE + t.relkind = ANY(dk.kinds) -- Filter by selected table relkinds ('r', 'p') + AND (NULLIF(TRIM($1), '') IS NULL OR t.relname = ANY(string_to_array($1,','))) -- $1 is object_names + AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast','google_ml') + AND ns.nspname NOT LIKE 'pg_temp_%' AND ns.nspname NOT LIKE 'pg_toast_temp_%' + ), + columns_info AS ( + SELECT + att.attrelid AS table_oid, att.attname AS column_name, format_type(att.atttypid, att.atttypmod) AS data_type, + att.attnum AS column_ordinal_position, att.attnotnull AS is_not_nullable, + pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(att.attrelid, att.attnum) AS column_comment + FROM pg_attribute att LEFT JOIN pg_attrdef ad ON att.attrelid = ad.adrelid AND att.attnum = ad.adnum + JOIN table_info ti ON att.attrelid = ti.table_oid WHERE att.attnum > 0 AND NOT att.attisdropped + ), + constraints_info AS ( + SELECT + con.conrelid AS table_oid, con.conname AS constraint_name, pg_get_constraintdef(con.oid) AS constraint_definition, + CASE con.contype WHEN 'p' THEN 'PRIMARY KEY' WHEN 'f' THEN 'FOREIGN KEY' WHEN 'u' THEN 'UNIQUE' WHEN 'c' THEN 'CHECK' ELSE con.contype::text END AS constraint_type, + (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.conkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = u.attnum) AS constraint_columns, + NULLIF(con.confrelid, 0)::regclass AS foreign_key_referenced_table, + (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.confkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.confrelid AND att.attnum = u.attnum WHERE con.contype = 'f') AS foreign_key_referenced_columns + FROM pg_constraint con JOIN table_info ti ON con.conrelid = ti.table_oid + ), + indexes_info AS ( + SELECT + idx.indrelid AS table_oid, ic.relname AS index_name, pg_get_indexdef(idx.indexrelid) AS index_definition, + idx.indisunique AS is_unique, idx.indisprimary AS is_primary, am.amname AS index_method, + (SELECT array_agg(att.attname ORDER BY u.ord) FROM unnest(idx.indkey::int[]) WITH ORDINALITY AS u(colidx, ord) LEFT JOIN pg_attribute att ON att.attrelid = idx.indrelid AND att.attnum = u.colidx WHERE u.colidx <> 0) AS index_columns + FROM pg_index idx JOIN pg_class ic ON ic.oid = idx.indexrelid JOIN pg_am am ON am.oid = ic.relam JOIN table_info ti ON idx.indrelid = ti.table_oid + ), + triggers_info AS ( + SELECT tg.tgrelid AS table_oid, tg.tgname AS trigger_name, pg_get_triggerdef(tg.oid) AS trigger_definition, tg.tgenabled AS trigger_enabled_state + FROM pg_trigger tg JOIN table_info ti ON tg.tgrelid = ti.table_oid WHERE NOT tg.tgisinternal + ) + SELECT + ti.schema_name, + ti.table_name AS object_name, + CASE + WHEN $2 = 'simple' THEN + -- IF format is 'simple', return basic JSON + json_build_object('name', ti.table_name) + ELSE + json_build_object( + 'schema_name', ti.schema_name, + 'object_name', ti.table_name, + 'object_type', CASE ti.object_kind + WHEN 'r' THEN 'TABLE' + WHEN 'p' THEN 'PARTITIONED TABLE' + ELSE ti.object_kind::text -- Should not happen due to WHERE clause + END, + 'owner', ti.table_owner, + 'comment', ti.table_comment, + 'columns', COALESCE((SELECT json_agg(json_build_object('column_name',ci.column_name,'data_type',ci.data_type,'ordinal_position',ci.column_ordinal_position,'is_not_nullable',ci.is_not_nullable,'column_default',ci.column_default,'column_comment',ci.column_comment) ORDER BY ci.column_ordinal_position) FROM columns_info ci WHERE ci.table_oid = ti.table_oid), '[]'::json), + 'constraints', COALESCE((SELECT json_agg(json_build_object('constraint_name',cons.constraint_name,'constraint_type',cons.constraint_type,'constraint_definition',cons.constraint_definition,'constraint_columns',cons.constraint_columns,'foreign_key_referenced_table',cons.foreign_key_referenced_table,'foreign_key_referenced_columns',cons.foreign_key_referenced_columns)) FROM constraints_info cons WHERE cons.table_oid = ti.table_oid), '[]'::json), + 'indexes', COALESCE((SELECT json_agg(json_build_object('index_name',ii.index_name,'index_definition',ii.index_definition,'is_unique',ii.is_unique,'is_primary',ii.is_primary,'index_method',ii.index_method,'index_columns',ii.index_columns)) FROM indexes_info ii WHERE ii.table_oid = ti.table_oid), '[]'::json), + 'triggers', COALESCE((SELECT json_agg(json_build_object('trigger_name',tri.trigger_name,'trigger_definition',tri.trigger_definition,'trigger_enabled_state',tri.trigger_enabled_state)) FROM triggers_info tri WHERE tri.table_oid = ti.table_oid), '[]'::json) + ) + END AS object_details + FROM table_info ti ORDER BY ti.schema_name, ti.table_name; ` func init() { @@ -133,7 +135,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -158,7 +159,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -168,31 +168,31 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() tableNames, ok := paramsMap["table_names"].(string) if !ok { - return nil, fmt.Errorf("invalid 'table_names' parameter; expected a string") + return nil, util.NewAgentError("invalid 'table_names' parameter; expected a string", nil) } outputFormat, _ := paramsMap["output_format"].(string) if outputFormat != "simple" && outputFormat != "detailed" { - return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) + return nil, util.NewAgentError(fmt.Sprintf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat), nil) } resp, err := source.RunSQL(ctx, listTablesStatement, []any{tableNames, outputFormat}) if err != nil { - return nil, err + return nil, util.ProcessGeneralError(err) } resSlice, ok := resp.([]any) if !ok || len(resSlice) == 0 { return []any{}, nil } - return resp, err + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go b/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go index 588caf8117..a5a3296dec 100644 --- a/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go +++ b/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go @@ -17,11 +17,13 @@ package postgreslisttablespaces import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,29 +31,29 @@ import ( const resourceType string = "postgres-list-tablespaces" const listTableSpacesStatement = ` - WITH - tablespace_info AS ( - SELECT - spcname AS tablespace_name, - pg_catalog.pg_get_userbyid(spcowner) AS owner_name, - CASE - WHEN pg_catalog.has_tablespace_privilege(oid, 'CREATE') THEN pg_tablespace_size(oid) - ELSE NULL - END AS size_in_bytes, - oid, - spcacl, - spcoptions - FROM - pg_tablespace - ) - SELECT * - FROM - tablespace_info - WHERE - ($1::text IS NULL OR tablespace_name LIKE '%' || $1::text || '%') - ORDER BY - tablespace_name - LIMIT COALESCE($2::int, 50); + WITH + tablespace_info AS ( + SELECT + spcname AS tablespace_name, + pg_catalog.pg_get_userbyid(spcowner) AS owner_name, + CASE + WHEN pg_catalog.has_tablespace_privilege(oid, 'CREATE') THEN pg_tablespace_size(oid) + ELSE NULL + END AS size_in_bytes, + oid, + spcacl, + spcoptions + FROM + pg_tablespace + ) + SELECT * + FROM + tablespace_info + WHERE + ($1::text IS NULL OR tablespace_name LIKE '%' || $1::text || '%') + ORDER BY + tablespace_name + LIMIT COALESCE($2::int, 50); ` func init() { @@ -81,7 +83,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -99,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -112,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -126,24 +125,28 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() tablespaceName, ok := paramsMap["tablespace_name"].(string) if !ok { - return nil, fmt.Errorf("invalid 'tablespace_name' parameter; expected a string") + return nil, util.NewAgentError("invalid 'tablespace_name' parameter; expected a string", nil) } limit, ok := paramsMap["limit"].(int) if !ok { - return nil, fmt.Errorf("invalid 'limit' parameter; expected an integer") + return nil, util.NewAgentError("invalid 'limit' parameter; expected an integer", nil) } - return source.RunSQL(ctx, listTableSpacesStatement, []any{tablespaceName, limit}) + resp, err := source.RunSQL(ctx, listTableSpacesStatement, []any{tablespaceName, limit}) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go b/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go index 8e5d8e3309..13c4a9b05c 100644 --- a/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go +++ b/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go @@ -17,11 +17,13 @@ package postgreslisttablestats import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-list-table-stats" const listTableStats = ` - WITH table_stats AS ( + WITH table_stats AS ( SELECT s.schemaname AS schema_name, s.relname AS table_name, @@ -102,7 +104,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -121,19 +122,18 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if cfg.Description == "" { cfg.Description = `Lists the user table statistics in the database ordered by number of - sequential scans with a default limit of 50 rows. Returns the following - columns: schema name, table name, table size in bytes, number of - sequential scans, number of index scans, idx_scan_ratio_percent (showing - the percentage of total scans that utilized an index, where a low ratio - indicates missing or ineffective indexes), number of live rows, number - of dead rows, dead_row_ratio_percent (indicating potential table bloat), - total number of rows inserted, updated, and deleted, the timestamps - for the last_vacuum, last_autovacuum, and last_autoanalyze operations.` + sequential scans with a default limit of 50 rows. Returns the following + columns: schema name, table name, table size in bytes, number of + sequential scans, number of index scans, idx_scan_ratio_percent (showing + the percentage of total scans that utilized an index, where a low ratio + indicates missing or ineffective indexes), number of live rows, number + of dead rows, dead_row_ratio_percent (indicating potential table bloat), + total number of rows inserted, updated, and deleted, the timestamps + for the last_vacuum, last_autovacuum, and last_autoanalyze operations.` } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -146,7 +146,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -160,21 +159,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listTableStats, sliceParams) + resp, err := source.RunSQL(ctx, listTableStats, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go b/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go index 810f242e62..63889bfb46 100644 --- a/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go +++ b/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go @@ -17,11 +17,13 @@ package postgreslisttriggers import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,49 +31,49 @@ import ( const resourceType string = "postgres-list-triggers" const listTriggersStatement = ` - WITH - trigger_list AS ( - SELECT - t.tgname AS trigger_name, - n.nspname AS schema_name, - c.relname AS table_name, - CASE t.tgenabled - WHEN 'O' THEN 'ENABLED' - WHEN 'D' THEN 'DISABLED' - WHEN 'R' THEN 'REPLICA' - WHEN 'A' THEN 'ALWAYS' - END AS status, - CASE - WHEN (t.tgtype::int & 2) = 2 THEN 'BEFORE' - WHEN (t.tgtype::int & 64) = 64 THEN 'INSTEAD OF' - ELSE 'AFTER' - END AS timing, - concat_ws( - ', ', - CASE WHEN (t.tgtype::int & 4) = 4 THEN 'INSERT' END, - CASE WHEN (t.tgtype::int & 16) = 16 THEN 'UPDATE' END, - CASE WHEN (t.tgtype::int & 8) = 8 THEN 'DELETE' END, - CASE WHEN (t.tgtype::int & 32) = 32 THEN 'TRUNCATE' END) AS events, - CASE WHEN (t.tgtype::int & 1) = 1 THEN 'ROW' ELSE 'STATEMENT' END AS activation_level, - p.proname AS function_name, - pg_get_triggerdef(t.oid) AS definition - FROM pg_trigger t - JOIN pg_class c - ON t.tgrelid = c.oid - JOIN pg_namespace n - ON c.relnamespace = n.oid - LEFT JOIN pg_proc p - ON t.tgfoid = p.oid - WHERE NOT t.tgisinternal - ) - SELECT * - FROM trigger_list - WHERE - ($1::text IS NULL OR trigger_name LIKE '%' || $1::text || '%') - AND ($2::text IS NULL OR schema_name LIKE '%' || $2::text || '%') - AND ($3::text IS NULL OR table_name LIKE '%' || $3::text || '%') - ORDER BY schema_name, table_name, trigger_name - LIMIT COALESCE($4::int, 50); + WITH + trigger_list AS ( + SELECT + t.tgname AS trigger_name, + n.nspname AS schema_name, + c.relname AS table_name, + CASE t.tgenabled + WHEN 'O' THEN 'ENABLED' + WHEN 'D' THEN 'DISABLED' + WHEN 'R' THEN 'REPLICA' + WHEN 'A' THEN 'ALWAYS' + END AS status, + CASE + WHEN (t.tgtype::int & 2) = 2 THEN 'BEFORE' + WHEN (t.tgtype::int & 64) = 64 THEN 'INSTEAD OF' + ELSE 'AFTER' + END AS timing, + concat_ws( + ', ', + CASE WHEN (t.tgtype::int & 4) = 4 THEN 'INSERT' END, + CASE WHEN (t.tgtype::int & 16) = 16 THEN 'UPDATE' END, + CASE WHEN (t.tgtype::int & 8) = 8 THEN 'DELETE' END, + CASE WHEN (t.tgtype::int & 32) = 32 THEN 'TRUNCATE' END) AS events, + CASE WHEN (t.tgtype::int & 1) = 1 THEN 'ROW' ELSE 'STATEMENT' END AS activation_level, + p.proname AS function_name, + pg_get_triggerdef(t.oid) AS definition + FROM pg_trigger t + JOIN pg_class c + ON t.tgrelid = c.oid + JOIN pg_namespace n + ON c.relnamespace = n.oid + LEFT JOIN pg_proc p + ON t.tgfoid = p.oid + WHERE NOT t.tgisinternal + ) + SELECT * + FROM trigger_list + WHERE + ($1::text IS NULL OR trigger_name LIKE '%' || $1::text || '%') + AND ($2::text IS NULL OR schema_name LIKE '%' || $2::text || '%') + AND ($3::text IS NULL OR table_name LIKE '%' || $3::text || '%') + ORDER BY schema_name, table_name, trigger_name + LIMIT COALESCE($4::int, 50); ` func init() { @@ -101,7 +103,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -121,7 +122,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -134,7 +134,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -148,20 +147,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listTriggersStatement, sliceParams) + resp, err := source.RunSQL(ctx, listTriggersStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistviews/postgreslistviews.go b/internal/tools/postgres/postgreslistviews/postgreslistviews.go index e2d49691fa..e4359b9759 100644 --- a/internal/tools/postgres/postgreslistviews/postgreslistviews.go +++ b/internal/tools/postgres/postgreslistviews/postgreslistviews.go @@ -17,11 +17,13 @@ package postgreslistviews import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,24 +31,24 @@ import ( const resourceType string = "postgres-list-views" const listViewsStatement = ` - WITH list_views AS ( - SELECT - schemaname AS schema_name, - viewname AS view_name, - viewowner AS owner_name, - definition - FROM pg_views - ) - SELECT * - FROM list_views - WHERE - schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') - AND schema_name NOT LIKE 'pg_temp_%' - AND ($1::text IS NULL OR view_name ILIKE '%' || $1::text || '%') - AND ($2::text IS NULL OR schema_name ILIKE '%' || $2::text || '%') - ORDER BY - schema_name, view_name - LIMIT COALESCE($3::int, 50); + WITH list_views AS ( + SELECT + schemaname AS schema_name, + viewname AS view_name, + viewowner AS owner_name, + definition + FROM pg_views + ) + SELECT * + FROM list_views + WHERE + schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') + AND schema_name NOT LIKE 'pg_temp_%' + AND ($1::text IS NULL OR view_name ILIKE '%' || $1::text || '%') + AND ($2::text IS NULL OR schema_name ILIKE '%' || $2::text || '%') + ORDER BY + schema_name, view_name + LIMIT COALESCE($3::int, 50); ` func init() { @@ -76,7 +78,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -95,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -108,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -118,20 +117,24 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listViewsStatement, sliceParams) + resp, err := source.RunSQL(ctx, listViewsStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go b/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go index af35731d0c..2664c2e419 100644 --- a/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go +++ b/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go @@ -17,11 +17,13 @@ package postgreslongrunningtransactions import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-long-running-transactions" const longRunningTransactions = ` - SELECT + SELECT pid, datname, usename, @@ -83,7 +85,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -103,7 +104,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -116,7 +116,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -130,20 +129,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, longRunningTransactions, sliceParams) + resp, err := source.RunSQL(ctx, longRunningTransactions, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go b/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go index bdb45cda4e..495c640140 100644 --- a/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go +++ b/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go @@ -17,11 +17,13 @@ package postgresreplicationstats import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,11 +31,11 @@ import ( const resourceType string = "postgres-replication-stats" const replicationStats = ` - SELECT - pid, - usename, + SELECT + pid, + usename, application_name, - backend_xmin, + backend_xmin, client_addr, state, sync_state, @@ -73,7 +75,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -90,7 +91,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -103,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -117,20 +116,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, replicationStats, sliceParams) + resp, err := source.RunSQL(ctx, replicationStats, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgressql/postgressql.go b/internal/tools/postgres/postgressql/postgressql.go index adfc6e830c..ece775a356 100644 --- a/internal/tools/postgres/postgressql/postgressql.go +++ b/internal/tools/postgres/postgressql/postgressql.go @@ -17,11 +17,13 @@ package postgressql import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -58,7 +60,6 @@ type Config struct { TemplateParameters parameters.Parameters `yaml:"templateParameters"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -73,7 +74,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup t := Tool{ Config: cfg, AllParams: allParameters, @@ -83,7 +83,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -93,24 +92,28 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + resp, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/redis/redis.go b/internal/tools/redis/redis.go index e3d56a1596..3aa3154354 100644 --- a/internal/tools/redis/redis.go +++ b/internal/tools/redis/redis.go @@ -16,12 +16,14 @@ package redis import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" redissrc "github.com/googleapis/genai-toolbox/internal/sources/redis" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -84,17 +86,21 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } cmds, err := replaceCommandsParams(t.Commands, t.Parameters, params) if err != nil { - return nil, fmt.Errorf("error replacing commands' parameters: %s", err) + return nil, util.NewAgentError("error replacing commands' parameters", err) } - return source.RunCommand(ctx, cmds) + resp, err := source.RunCommand(ctx, cmds) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/serverlessspark/createbatch/tool.go b/internal/tools/serverlessspark/createbatch/tool.go index 899c25d11e..cbbbc4c920 100644 --- a/internal/tools/serverlessspark/createbatch/tool.go +++ b/internal/tools/serverlessspark/createbatch/tool.go @@ -17,11 +17,13 @@ package createbatch import ( "context" "fmt" + "net/http" dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/protobuf/proto" ) @@ -65,15 +67,18 @@ type Tool struct { Parameters parameters.Parameters } -func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } batch, err := t.Builder.BuildBatch(params) if err != nil { - return nil, fmt.Errorf("failed to build batch: %w", err) + if tbErr, ok := err.(util.ToolboxError); ok { + return nil, tbErr + } + return nil, util.NewAgentError("failed to build batch", err) } if t.RuntimeConfig != nil { @@ -92,11 +97,20 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par } batch.RuntimeConfig.Version = version } - return source.CreateBatch(ctx, batch) + + resp, err := source.CreateBatch(ctx, batch) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { - return parameters.EmbedParams(ctx, t.Parameters, paramValues, embeddingModelsMap, nil) + newParamValues, err := parameters.EmbedParams(ctx, t.Parameters, paramValues, embeddingModelsMap, nil) + if err != nil { + return nil, util.NewClientServerError(fmt.Sprintf("error embedding parameters: %v", err), http.StatusInternalServerError, err) + } + return newParamValues, nil } func (t *Tool) Manifest() tools.Manifest { diff --git a/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go index d931bb81e0..6aeb901f73 100644 --- a/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go +++ b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go @@ -17,6 +17,7 @@ package serverlesssparkcancelbatch import ( "context" "fmt" + "net/http" "strings" dataproc "cloud.google.com/go/dataproc/v2/apiv1" @@ -24,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -99,20 +101,26 @@ type Tool struct { } // Invoke executes the tool's operation. -func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } + paramMap := params.AsMap() operation, ok := paramMap["operation"].(string) if !ok { - return nil, fmt.Errorf("missing required parameter: operation") + return nil, util.NewAgentError("missing required parameter: operation", nil) } if strings.Contains(operation, "/") { - return nil, fmt.Errorf("operation must be a short operation name without '/': %s", operation) + return nil, util.NewAgentError(fmt.Sprintf("operation must be a short operation name without '/': %s", operation), nil) } - return source.CancelOperation(ctx, operation) + + resp, err := source.CancelOperation(ctx, operation) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go b/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go index f038280a1f..f00772dadd 100644 --- a/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go +++ b/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go @@ -17,6 +17,7 @@ package serverlesssparkgetbatch import ( "context" "fmt" + "net/http" "strings" dataproc "cloud.google.com/go/dataproc/v2/apiv1" @@ -24,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -99,20 +101,25 @@ type Tool struct { } // Invoke executes the tool's operation. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramMap := params.AsMap() name, ok := paramMap["name"].(string) if !ok { - return nil, fmt.Errorf("missing required parameter: name") + return nil, util.NewAgentError("missing required parameter: name", nil) } if strings.Contains(name, "/") { - return nil, fmt.Errorf("name must be a short batch name without '/': %s", name) + return nil, util.NewAgentError(fmt.Sprintf("name must be a short batch name without '/': %s", name), nil) } - return source.GetBatch(ctx, name) + + resp, err := source.GetBatch(ctx, name) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go index 64d56b01a7..0c820d4950 100644 --- a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go +++ b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go @@ -17,12 +17,14 @@ package serverlesssparklistbatches import ( "context" "fmt" + "net/http" dataproc "cloud.google.com/go/dataproc/v2/apiv1" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -100,23 +102,39 @@ type Tool struct { } // Invoke executes the tool's operation. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } + paramMap := params.AsMap() var pageSize *int if ps, ok := paramMap["pageSize"]; ok && ps != nil { - pageSizeV := ps.(int) + pageSizeV, ok := ps.(int) + if !ok { + // Handle float64 case if unmarshaled from JSON usually + if f, ok := ps.(float64); ok { + pageSizeV = int(f) + } else { + return nil, util.NewAgentError("pageSize must be an integer", nil) + } + } + if pageSizeV <= 0 { - return nil, fmt.Errorf("pageSize must be positive: %d", pageSizeV) + return nil, util.NewAgentError(fmt.Sprintf("pageSize must be positive: %d", pageSizeV), nil) } pageSize = &pageSizeV } + pt, _ := paramMap["pageToken"].(string) filter, _ := paramMap["filter"].(string) - return source.ListBatches(ctx, pageSize, pt, filter) + + resp, err := source.ListBatches(ctx, pageSize, pt, filter) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go b/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go index 8eb3c2dc6e..c10e0e375e 100644 --- a/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go +++ b/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -98,25 +99,30 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the provided SQL query using the tool's database connection and returns the results. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - sql, ok := paramsMap["sql"].(string) + sqlStr, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast parameter 'sql' to string: %v", paramsMap["sql"]), nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } - logger.DebugContext(ctx, "executing `%s` tool query: %s", resourceType, sql) - return source.RunSQL(ctx, sql, nil) + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sqlStr)) + + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/singlestore/singlestoresql/singlestoresql.go b/internal/tools/singlestore/singlestoresql/singlestoresql.go index 5984edc2a0..3350390c7d 100644 --- a/internal/tools/singlestore/singlestoresql/singlestoresql.go +++ b/internal/tools/singlestore/singlestoresql/singlestoresql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -126,25 +128,29 @@ func (t Tool) ToConfig() tools.ToolConfig { // Returns: // - A slice of maps, where each map represents a row with column names as keys. // - An error if template resolution, parameter extraction, query execution, or result processing fails. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + resp, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/snowflake/snowflakeexecutesql/snowflakeexecutesql.go b/internal/tools/snowflake/snowflakeexecutesql/snowflakeexecutesql.go index e83a7912e4..6a85001d0d 100644 --- a/internal/tools/snowflake/snowflakeexecutesql/snowflakeexecutesql.go +++ b/internal/tools/snowflake/snowflakeexecutesql/snowflakeexecutesql.go @@ -17,6 +17,7 @@ package snowflakeexecutesql import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -89,26 +90,30 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() - sql, ok := mapParams["sql"].(string) + sqlStr, ok := mapParams["sql"].(string) if !ok { - return nil, fmt.Errorf("invalid parameters: sql parameter is not a string") + return nil, util.NewAgentError("invalid parameters: sql parameter is not a string", nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, sql, nil) + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sqlStr)) + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/snowflake/snowflakesql/snowflakesql.go b/internal/tools/snowflake/snowflakesql/snowflakesql.go index a2eb670ea6..e5a9835d98 100644 --- a/internal/tools/snowflake/snowflakesql/snowflakesql.go +++ b/internal/tools/snowflake/snowflakesql/snowflakesql.go @@ -17,11 +17,13 @@ package snowflakesql import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jmoiron/sqlx" ) @@ -93,25 +95,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + resp, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go b/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go index f91d6579c0..94f5b1e7c5 100644 --- a/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go +++ b/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go @@ -17,6 +17,7 @@ package spannerexecutesql import ( "context" "fmt" + "net/http" "cloud.google.com/go/spanner" yaml "github.com/goccy/go-yaml" @@ -91,25 +92,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql"]), nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, t.ReadOnly, sql, nil) + resp, err := source.RunSQL(ctx, t.ReadOnly, sql, nil) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go b/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go index d4b7610421..ed8d74a08e 100644 --- a/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go +++ b/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go @@ -17,6 +17,7 @@ package spannerlistgraphs import ( "context" "fmt" + "net/http" "strings" "cloud.google.com/go/spanner" @@ -24,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -105,15 +107,15 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } // Check dialect here at RUNTIME instead of startup if strings.ToLower(source.DatabaseDialect()) != "googlesql" { - return nil, fmt.Errorf("operation not supported: The 'spanner-list-graphs' tool is only available for GoogleSQL dialect databases. Your current database dialect is '%s'", source.DatabaseDialect()) + return nil, util.NewAgentError(fmt.Sprintf("operation not supported: The 'spanner-list-graphs' tool is only available for GoogleSQL dialect databases. Your current database dialect is '%s'", source.DatabaseDialect()), nil) } paramsMap := params.AsMap() @@ -128,7 +130,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para "graph_names": graphNames, "output_format": outputFormat, } - return source.RunSQL(ctx, true, googleSQLStatement, stmtParams) + resp, err := source.RunSQL(ctx, true, googleSQLStatement, stmtParams) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/spanner/spannerlisttables/spannerlisttables.go b/internal/tools/spanner/spannerlisttables/spannerlisttables.go index 0bb3048dba..f301183903 100644 --- a/internal/tools/spanner/spannerlisttables/spannerlisttables.go +++ b/internal/tools/spanner/spannerlisttables/spannerlisttables.go @@ -17,6 +17,7 @@ package spannerlisttables import ( "context" "fmt" + "net/http" "strings" "cloud.google.com/go/spanner" @@ -24,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -117,10 +119,10 @@ func getStatement(dialect string) string { } } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() @@ -131,8 +133,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Prepare parameters based on dialect var stmtParams map[string]interface{} - tableNames, _ := paramsMap["table_names"].(string) - outputFormat, _ := paramsMap["output_format"].(string) + tableNames, ok := paramsMap["table_names"].(string) + if !ok { + return nil, util.NewAgentError("unable to get cast table_names", nil) + } + outputFormat, ok := paramsMap["output_format"].(string) + if !ok { + return nil, util.NewAgentError("unable to get cast output_format", nil) + } if outputFormat == "" { outputFormat = "detailed" } @@ -151,10 +159,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para "output_format": outputFormat, } default: - return nil, fmt.Errorf("unsupported dialect: %s", source.DatabaseDialect()) + return nil, util.NewAgentError(fmt.Sprintf("unsupported dialect: %s", source.DatabaseDialect()), nil) } - return source.RunSQL(ctx, true, statement, stmtParams) + resp, err := source.RunSQL(ctx, true, statement, stmtParams) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/spanner/spannersql/spannersql.go b/internal/tools/spanner/spannersql/spannersql.go index 5e11ae04aa..810d1d2d09 100644 --- a/internal/tools/spanner/spannersql/spannersql.go +++ b/internal/tools/spanner/spannersql/spannersql.go @@ -17,6 +17,7 @@ package spannersql import ( "context" "fmt" + "net/http" "strings" "cloud.google.com/go/spanner" @@ -24,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -103,25 +105,25 @@ func getMapParams(params parameters.ParamValues, dialect string) (map[string]int case "postgresql": return params.AsMapByOrderedKeys(), nil default: - return nil, fmt.Errorf("invalid dialect %s", dialect) + return nil, util.NewAgentError(fmt.Sprintf("invalid dialect %s", dialect), nil) } } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("unable to extract template params: %v", err), http.StatusInternalServerError, err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewClientServerError(fmt.Sprintf("unable to extract standard params: %v", err), http.StatusInternalServerError, err) } for i, p := range t.Parameters { @@ -135,13 +137,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para case *parameters.ArrayParameter: arrayParamValue, ok := value.([]any) if !ok { - return nil, fmt.Errorf("unable to convert parameter `%s` to []any %w", name, err) + return nil, util.NewClientServerError(fmt.Sprintf("unable to convert parameter `%s` to []any", name), http.StatusInternalServerError, err) } itemType := arrayParam.GetItems().GetType() - var err error - value, err = parameters.ConvertAnySliceToTyped(arrayParamValue, itemType) - if err != nil { - return nil, fmt.Errorf("unable to convert parameter `%s` from []any to typed slice: %w", name, err) + var convertErr error + value, convertErr = parameters.ConvertAnySliceToTyped(arrayParamValue, itemType) + if convertErr != nil { + return nil, util.NewClientServerError(fmt.Sprintf("unable to convert parameter `%s` from []any to typed slice: %v", name, convertErr), http.StatusInternalServerError, convertErr) } } newParams[i] = parameters.ParamValue{Name: name, Value: value} @@ -149,9 +151,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para mapParams, err := getMapParams(newParams, source.DatabaseDialect()) if err != nil { - return nil, fmt.Errorf("fail to get map params: %w", err) + return nil, util.NewAgentError("fail to get map params", err) } - return source.RunSQL(ctx, t.ReadOnly, newStatement, mapParams) + + resp, err := source.RunSQL(ctx, t.ReadOnly, newStatement, mapParams) + if err != nil { + return nil, util.ProcessGcpError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go index fe2b287fa0..32ae860bfe 100644 --- a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go +++ b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -88,27 +89,32 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } - sql, ok := params.AsMap()["sql"].(string) + sqlStr, ok := params.AsMap()["sql"].(string) if !ok { - return nil, fmt.Errorf("missing or invalid 'sql' parameter") + return nil, util.NewAgentError("missing or invalid 'sql' parameter", nil) } - if sql == "" { - return nil, fmt.Errorf("sql parameter cannot be empty") + if sqlStr == "" { + return nil, util.NewAgentError("sql parameter cannot be empty", nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, sql, nil) + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sqlStr)) + + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/sqlite/sqlitesql/sqlitesql.go b/internal/tools/sqlite/sqlitesql/sqlitesql.go index 0b1e72ba7d..9f9e06f499 100644 --- a/internal/tools/sqlite/sqlitesql/sqlitesql.go +++ b/internal/tools/sqlite/sqlitesql/sqlitesql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -93,23 +95,27 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } - return source.RunSQL(ctx, newStatement, newParams.AsSlice()) + resp, err := source.RunSQL(ctx, newStatement, newParams.AsSlice()) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go b/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go index 81286223ab..8bb246ffb0 100644 --- a/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go +++ b/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -89,25 +90,30 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - sql, ok := paramsMap["sql"].(string) + sqlStr, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast parameter 'sql' to string: %v", paramsMap["sql"]), nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, sql, nil) + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sqlStr)) + + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/tidb/tidbsql/tidbsql.go b/internal/tools/tidb/tidbsql/tidbsql.go index dbeac8f64c..4e9abbc890 100644 --- a/internal/tools/tidb/tidbsql/tidbsql.go +++ b/internal/tools/tidb/tidbsql/tidbsql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -93,25 +95,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source not compatible with this tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + res, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return res, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/tools.go b/internal/tools/tools.go index 5950eadd82..93f2654c85 100644 --- a/internal/tools/tools.go +++ b/internal/tools/tools.go @@ -17,6 +17,7 @@ package tools import ( "context" "fmt" + "net/http" "slices" "strings" @@ -80,13 +81,13 @@ type AccessToken string func (token AccessToken) ParseBearerToken() (string, error) { headerParts := strings.Split(string(token), " ") if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" { - return "", fmt.Errorf("authorization header must be in the format 'Bearer ': %w", util.ErrUnauthorized) + return "", util.NewClientServerError("authorization header must be in the format 'Bearer '", http.StatusUnauthorized, nil) } return headerParts[1], nil } type Tool interface { - Invoke(context.Context, SourceProvider, parameters.ParamValues, AccessToken) (any, error) + Invoke(context.Context, SourceProvider, parameters.ParamValues, AccessToken) (any, util.ToolboxError) EmbedParams(context.Context, parameters.ParamValues, map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) Manifest() Manifest McpManifest() McpManifest diff --git a/internal/tools/trino/trinoexecutesql/trinoexecutesql.go b/internal/tools/trino/trinoexecutesql/trinoexecutesql.go index 2275e402d7..2ea25f9e2e 100644 --- a/internal/tools/trino/trinoexecutesql/trinoexecutesql.go +++ b/internal/tools/trino/trinoexecutesql/trinoexecutesql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -88,18 +90,22 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source not compatible with this tool", http.StatusInternalServerError, err) } sliceParams := params.AsSlice() sql, ok := sliceParams[0].(string) if !ok { - return nil, fmt.Errorf("unable to cast sql parameter: %v", sliceParams[0]) + return nil, util.NewAgentError("unable to cast the `sql` input parameter into string", nil) } - return source.RunSQL(ctx, sql, nil) + res, err := source.RunSQL(ctx, sql, []any{}) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return res, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/trino/trinosql/trinosql.go b/internal/tools/trino/trinosql/trinosql.go index edbd6f2d57..3641d22cf0 100644 --- a/internal/tools/trino/trinosql/trinosql.go +++ b/internal/tools/trino/trinosql/trinosql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -93,23 +95,27 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source not compatible with this tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + res, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return res, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/utility/wait/wait.go b/internal/tools/utility/wait/wait.go index e6638da2fc..32e752d113 100644 --- a/internal/tools/utility/wait/wait.go +++ b/internal/tools/utility/wait/wait.go @@ -23,6 +23,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -81,17 +82,17 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { paramsMap := params.AsMap() durationStr, ok := paramsMap["duration"].(string) if !ok { - return nil, fmt.Errorf("duration parameter is not a string") + return nil, util.NewAgentError("duration parameter is not a string", nil) } totalDuration, err := time.ParseDuration(durationStr) if err != nil { - return nil, fmt.Errorf("invalid duration format: %w", err) + return nil, util.NewAgentError("invalid duration format", err) } time.Sleep(totalDuration) diff --git a/internal/tools/valkey/valkey.go b/internal/tools/valkey/valkey.go index 46be19f886..95c7674832 100644 --- a/internal/tools/valkey/valkey.go +++ b/internal/tools/valkey/valkey.go @@ -16,11 +16,13 @@ package valkey import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/valkey-io/valkey-go" ) @@ -84,18 +86,22 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source not compatible with this tool", http.StatusInternalServerError, nil) } // Replace parameters commands, err := replaceCommandsParams(t.Commands, t.Parameters, params) if err != nil { - return nil, fmt.Errorf("error replacing commands' parameters: %s", err) + return nil, util.NewAgentError("error replacing commands' parameters", err) } - return source.RunCommand(ctx, commands) + res, err := source.RunCommand(ctx, commands) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return res, nil } // replaceCommandsParams is a helper function to replace parameters in the commands diff --git a/internal/tools/yugabytedbsql/yugabytedbsql.go b/internal/tools/yugabytedbsql/yugabytedbsql.go index d97fd1dea2..6eb3f51f6c 100644 --- a/internal/tools/yugabytedbsql/yugabytedbsql.go +++ b/internal/tools/yugabytedbsql/yugabytedbsql.go @@ -17,11 +17,13 @@ package yugabytedbsql import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" embeddingmodels "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/yugabyte/pgx/v5/pgxpool" ) @@ -93,24 +95,28 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source not compatible with this tool", http.StatusInternalServerError, nil) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + res, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return res, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/util/errors.go b/internal/util/errors.go index 38dd7f5954..e8d5328bda 100644 --- a/internal/util/errors.go +++ b/internal/util/errors.go @@ -12,7 +12,14 @@ // limitations under the License. package util -import "fmt" +import ( + "errors" + "fmt" + "net/http" + "strings" + + "google.golang.org/api/googleapi" +) type ErrorCategory string @@ -52,6 +59,8 @@ func NewAgentError(msg string, cause error) *AgentError { return &AgentError{Msg: msg, Cause: cause} } +var _ ToolboxError = &AgentError{} + // ClientServerError returns 4XX/5XX error code type ClientServerError struct { Msg string @@ -75,3 +84,57 @@ func (e *ClientServerError) Unwrap() error { return e.Cause } func NewClientServerError(msg string, code int, cause error) *ClientServerError { return &ClientServerError{Msg: msg, Code: code, Cause: cause} } + +// ProcessGcpError catches auth related errors in GCP requests results and return 401/403 error codes +// Returns AgentError for all other errors +func ProcessGcpError(err error) ToolboxError { + var gErr *googleapi.Error + if errors.As(err, &gErr) { + if gErr.Code == 401 { + return NewClientServerError( + "failed to access GCP resource", + http.StatusUnauthorized, + err, + ) + } + if gErr.Code == 403 { + return NewClientServerError( + "failed to access GCP resource", + http.StatusForbidden, + err, + ) + } + } + return NewAgentError("error processing GCP request", err) +} + +// ProcessGeneralError handles generic errors by inspecting the error string +// for common status code patterns. +func ProcessGeneralError(err error) ToolboxError { + if err == nil { + return nil + } + + errStr := err.Error() + + // Check for Unauthorized + if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "status 401") { + return NewClientServerError( + "failed to access resource", + http.StatusUnauthorized, + err, + ) + } + + // Check for Forbidden + if strings.Contains(errStr, "Error 403") || strings.Contains(errStr, "status 403") { + return NewClientServerError( + "failed to access resource", + http.StatusForbidden, + err, + ) + } + + // Default to AgentError for logical failures (task execution failed) + return NewAgentError("error processing request", err) +} diff --git a/internal/util/parameters/parameters.go b/internal/util/parameters/parameters.go index f75da04a5d..7c991f61be 100644 --- a/internal/util/parameters/parameters.go +++ b/internal/util/parameters/parameters.go @@ -19,6 +19,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "reflect" "regexp" "slices" @@ -118,7 +119,7 @@ func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[st } return v, nil } - return nil, fmt.Errorf("missing or invalid authentication header: %w", util.ErrUnauthorized) + return nil, util.NewClientServerError("missing or invalid authentication header", http.StatusUnauthorized, nil) } // CheckParamRequired checks if a parameter is required based on the required and default field. @@ -147,20 +148,20 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st v = p.GetDefault() // if the parameter is required and no value given, throw an error if CheckParamRequired(p.GetRequired(), v) { - return nil, fmt.Errorf("parameter %q is required", name) + return nil, util.NewAgentError(fmt.Sprintf("parameter %q is required", name), nil) } } } else { // parse authenticated parameter v, err = parseFromAuthService(paramAuthServices, claimsMap) if err != nil { - return nil, fmt.Errorf("error parsing authenticated parameter %q: %w", name, err) + return nil, util.NewClientServerError(fmt.Sprintf("error parsing authenticated parameter %q", name), http.StatusUnauthorized, err) } } if v != nil { newV, err = p.Parse(v) if err != nil { - return nil, fmt.Errorf("unable to parse value for %q: %w", name, err) + return nil, util.NewAgentError(fmt.Sprintf("unable to parse value for %q", name), err) } } params = append(params, ParamValue{Name: name, Value: newV}) diff --git a/internal/util/util.go b/internal/util/util.go index 657fe8bf29..7ac50f6b6e 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -17,7 +17,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -188,5 +187,3 @@ func InstrumentationFromContext(ctx context.Context) (*telemetry.Instrumentation } return nil, fmt.Errorf("unable to retrieve instrumentation") } - -var ErrUnauthorized = errors.New("unauthorized") diff --git a/tests/alloydb/alloydb_integration_test.go b/tests/alloydb/alloydb_integration_test.go index 52ad7731f3..0cad64ba74 100644 --- a/tests/alloydb/alloydb_integration_test.go +++ b/tests/alloydb/alloydb_integration_test.go @@ -402,7 +402,7 @@ func runAlloyDBListClustersTest(t *testing.T, vars map[string]string) { { name: "list clusters missing project", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"location": "%s"}`, vars["location"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "list clusters non-existent location", @@ -417,12 +417,12 @@ func runAlloyDBListClustersTest(t *testing.T, vars map[string]string) { { name: "list clusters empty project", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "", "location": "%s"}`, vars["location"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "list clusters empty location", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": ""}`, vars["project"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, } @@ -489,42 +489,42 @@ func runAlloyDBListUsersTest(t *testing.T, vars map[string]string) { requestBody io.Reader wantContains string wantStatusCode int + expectAgentErr bool }{ { name: "list users success", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s"}`, vars["project"], vars["location"], vars["cluster"])), wantContains: fmt.Sprintf("projects/%s/locations/%s/clusters/%s/users/%s", vars["project"], vars["location"], vars["cluster"], AlloyDBUser), wantStatusCode: http.StatusOK, + expectAgentErr: false, }, { name: "list users missing project", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"location": "%s", "cluster": "%s"}`, vars["location"], vars["cluster"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, + wantContains: `parameter \"project\" is required`, + expectAgentErr: true, }, { name: "list users missing location", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "cluster": "%s"}`, vars["project"], vars["cluster"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, + wantContains: `parameter \"location\" is required`, + expectAgentErr: true, }, { name: "list users missing cluster", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s"}`, vars["project"], vars["cluster"])), - wantStatusCode: http.StatusBadRequest, - }, - { - name: "list users non-existent project", - requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "non-existent-project", "location": "%s", "cluster": "%s"}`, vars["location"], vars["cluster"])), - wantStatusCode: http.StatusInternalServerError, - }, - { - name: "list users non-existent location", - requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "non-existent-location", "cluster": "%s"}`, vars["project"], vars["cluster"])), - wantStatusCode: http.StatusInternalServerError, + wantStatusCode: http.StatusOK, + wantContains: `parameter \"cluster\" is required`, + expectAgentErr: true, }, { name: "list users non-existent cluster", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "non-existent-cluster"}`, vars["project"], vars["location"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, + wantContains: `was not found`, + expectAgentErr: true, }, } @@ -544,7 +544,7 @@ func runAlloyDBListUsersTest(t *testing.T, vars map[string]string) { if resp.StatusCode != tc.wantStatusCode { bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatusCode, resp.StatusCode, string(bodyBytes)) + t.Fatalf("response status code: got %d, want %d: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) } if tc.wantStatusCode == http.StatusOK { @@ -553,27 +553,28 @@ func runAlloyDBListUsersTest(t *testing.T, vars map[string]string) { t.Fatalf("error parsing outer response body: %v", err) } - var usersData UsersResponse - if err := json.Unmarshal([]byte(body.Result), &usersData); err != nil { - t.Fatalf("error parsing nested result JSON: %v", err) - } - - var got []string - for _, user := range usersData.Users { - got = append(got, user.Name) - } - - sort.Strings(got) - - found := false - for _, g := range got { - if g == tc.wantContains { - found = true - break + if tc.expectAgentErr { + // Logic for checking wrapped error messages + if !strings.Contains(body.Result, tc.wantContains) { + t.Errorf("expected agent error message not found:\n got: %s\nwant: %s", body.Result, tc.wantContains) + } + } else { + // Logic for checking successful resource lists + var usersData UsersResponse + if err := json.Unmarshal([]byte(body.Result), &usersData); err != nil { + t.Fatalf("error parsing nested result JSON: %v. Result was: %s", err, body.Result) + } + + found := false + for _, user := range usersData.Users { + if user.Name == tc.wantContains { + found = true + break + } + } + if !found { + t.Errorf("expected user name %q not found in response", tc.wantContains) } - } - if !found { - t.Errorf("wantContains not found in response:\n got: %v\nwant: %v", got, tc.wantContains) } } }) @@ -636,7 +637,7 @@ func runAlloyDBListInstancesTest(t *testing.T, vars map[string]string) { { name: "list instances missing project", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"location": "%s", "cluster": "%s"}`, vars["location"], vars["cluster"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "list instances non-existent project", @@ -651,7 +652,7 @@ func runAlloyDBListInstancesTest(t *testing.T, vars map[string]string) { { name: "list instances non-existent cluster", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "non-existent-cluster"}`, vars["project"], vars["location"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, } @@ -725,22 +726,22 @@ func runAlloyDBGetClusterTest(t *testing.T, vars map[string]string) { { name: "get cluster missing project", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"location": "%s", "cluster": "%s"}`, vars["location"], vars["cluster"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "get cluster missing location", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "cluster": "%s"}`, vars["project"], vars["cluster"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "get cluster missing cluster", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s"}`, vars["project"], vars["location"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "get cluster non-existent cluster", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "non-existent-cluster"}`, vars["project"], vars["location"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, } @@ -815,27 +816,27 @@ func runAlloyDBGetInstanceTest(t *testing.T, vars map[string]string) { { name: "get instance missing project", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"location": "%s", "cluster": "%s", "instance": "%s"}`, vars["location"], vars["cluster"], vars["instance"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "get instance missing location", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "cluster": "%s", "instance": "%s"}`, vars["project"], vars["cluster"], vars["instance"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "get instance missing cluster", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "instance": "%s"}`, vars["project"], vars["location"], vars["instance"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "get instance missing instance", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s"}`, vars["project"], vars["location"], vars["cluster"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "get instance non-existent instance", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s", "instance": "non-existent-instance"}`, vars["project"], vars["location"], vars["cluster"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, } @@ -910,27 +911,27 @@ func runAlloyDBGetUserTest(t *testing.T, vars map[string]string) { { name: "get user missing project", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"location": "%s", "cluster": "%s", "user": "%s"}`, vars["location"], vars["cluster"], vars["user"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "get user missing location", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "cluster": "%s", "user": "%s"}`, vars["project"], vars["cluster"], vars["user"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "get user missing cluster", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "user": "%s"}`, vars["project"], vars["location"], vars["user"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "get user missing user", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s"}`, vars["project"], vars["location"], vars["cluster"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, { name: "get non-existent user", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s", "user": "non-existent-user"}`, vars["project"], vars["location"], vars["cluster"])), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, }, } @@ -1129,26 +1130,26 @@ func TestAlloyDBCreateCluster(t *testing.T) { { name: "api failure", body: `{"project": "p1", "location": "l1", "cluster": "c2-api-failure", "password": "p1"}`, - want: "internal api error", - wantStatusCode: http.StatusBadRequest, + want: `{"error":"error processing GCP request: error creating AlloyDB cluster: googleapi: Error 500: internal api error"}`, + wantStatusCode: http.StatusOK, }, { name: "missing project", body: `{"location": "l1", "cluster": "c1", "password": "p1"}`, - want: `parameter \"project\" is required`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"parameter \"project\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "missing cluster", body: `{"project": "p1", "location": "l1", "password": "p1"}`, - want: `parameter \"cluster\" is required`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"parameter \"cluster\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "missing password", body: `{"project": "p1", "location": "l1", "cluster": "c1"}`, - want: `parameter \"password\" is required`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"parameter \"password\" is required"}`, + wantStatusCode: http.StatusOK, }, } @@ -1239,38 +1240,38 @@ func TestAlloyDBCreateInstance(t *testing.T) { { name: "api failure", body: `{"project": "p1", "location": "l1", "cluster": "c1", "instance": "i2-api-failure", "instanceType": "PRIMARY", "displayName": "i1-success"}`, - want: "internal api error", - wantStatusCode: http.StatusBadRequest, + want: `{"error":"error processing GCP request: error creating AlloyDB instance: googleapi: Error 500: internal api error"}`, + wantStatusCode: http.StatusOK, }, { name: "missing project", body: `{"location": "l1", "cluster": "c1", "instance": "i1", "instanceType": "PRIMARY"}`, - want: `parameter \"project\" is required`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"parameter \"project\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "missing cluster", body: `{"project": "p1", "location": "l1", "instance": "i1", "instanceType": "PRIMARY"}`, - want: `parameter \"cluster\" is required`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"parameter \"cluster\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "missing location", body: `{"project": "p1", "cluster": "c1", "instance": "i1", "instanceType": "PRIMARY"}`, - want: `parameter \"location\" is required`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"parameter \"location\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "missing instance", body: `{"project": "p1", "location": "l1", "cluster": "c1", "instanceType": "PRIMARY"}`, - want: `parameter \"instance\" is required`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"parameter \"instance\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "invalid instanceType", body: `{"project": "p1", "location": "l1", "cluster": "c1", "instance": "i1", "instanceType": "INVALID", "displayName": "invalid"}`, - want: `invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'"}`, + wantStatusCode: http.StatusOK, }, } @@ -1371,50 +1372,50 @@ func TestAlloyDBCreateUser(t *testing.T) { { name: "api failure", body: `{"project": "p1", "location": "l1", "cluster": "c1", "user": "u3-api-failure", "userType": "ALLOYDB_IAM_USER"}`, - want: "user internal api error", - wantStatusCode: http.StatusBadRequest, + want: `{"error":"error processing GCP request: error creating AlloyDB user: googleapi: Error 500: user internal api error"}`, + wantStatusCode: http.StatusOK, }, { name: "missing project", body: `{"location": "l1", "cluster": "c1", "user": "u-fail", "userType": "ALLOYDB_IAM_USER"}`, - want: `parameter \"project\" is required`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"parameter \"project\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "missing cluster", body: `{"project": "p1", "location": "l1", "user": "u-fail", "userType": "ALLOYDB_IAM_USER"}`, - want: `parameter \"cluster\" is required`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"parameter \"cluster\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "missing location", body: `{"project": "p1", "cluster": "c1", "user": "u-fail", "userType": "ALLOYDB_IAM_USER"}`, - want: `parameter \"location\" is required`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"parameter \"location\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "missing user", body: `{"project": "p1", "location": "l1", "cluster": "c1", "userType": "ALLOYDB_IAM_USER"}`, - want: `parameter \"user\" is required`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"parameter \"user\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "missing userType", body: `{"project": "p1", "location": "l1", "cluster": "c1", "user": "u-fail"}`, - want: `parameter \"userType\" is required`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"parameter \"userType\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "missing password for builtin user", body: `{"project": "p1", "location": "l1", "cluster": "c1", "user": "u-fail", "userType": "ALLOYDB_BUILT_IN"}`, - want: `password is required when userType is ALLOYDB_BUILT_IN`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"password is required when userType is ALLOYDB_BUILT_IN"}`, + wantStatusCode: http.StatusOK, }, { name: "invalid userType", body: `{"project": "p1", "location": "l1", "cluster": "c1", "user": "u-fail", "userType": "invalid"}`, - want: `invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'`, - wantStatusCode: http.StatusBadRequest, + want: `{"error":"invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'"}`, + wantStatusCode: http.StatusOK, }, } diff --git a/tests/alloydb/alloydb_wait_for_operation_test.go b/tests/alloydb/alloydb_wait_for_operation_test.go index 38dece22d0..c82ab9e1c8 100644 --- a/tests/alloydb/alloydb_wait_for_operation_test.go +++ b/tests/alloydb/alloydb_wait_for_operation_test.go @@ -23,7 +23,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "reflect" "regexp" "strings" "sync" @@ -165,13 +164,13 @@ func TestWaitToolEndpoints(t *testing.T) { name: "successful operation", toolName: "wait-for-op1", body: `{"project": "p1", "location": "l1", "operation": "op1"}`, - want: `{"name":"op1","done":true,"response":"success"}`, + want: `{"done":true,"name":"op1","response":"success"}`, }, { - name: "failed operation", - toolName: "wait-for-op2", - body: `{"project": "p1", "location": "l1", "operation": "op2"}`, - expectError: true, + name: "failed operation", + toolName: "wait-for-op2", + body: `{"project": "p1", "location": "l1", "operation": "op2"}`, + want: `{"error":"error processing request: operation finished with error: {\"code\":1,\"message\":\"failed\"}"}`, }, } @@ -189,48 +188,42 @@ func TestWaitToolEndpoints(t *testing.T) { } defer resp.Body.Close() - if tc.expectError { - if resp.StatusCode == http.StatusOK { - t.Fatal("expected error but got status 200") - } - return - } - if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) } - - var result struct { - Result string `json:"result"` + var response struct { + Result any `json:"result"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { t.Fatalf("failed to decode response: %v", err) } + var got string + // Check if the result is a string (which contains JSON) + if s, ok := response.Result.(string); ok { + got = s + } else { + b, err := json.Marshal(response.Result) + if err != nil { + t.Fatalf("failed to marshal result object: %v", err) + } + got = string(b) + } + + // Clean up both strings to ignore whitespace differences + got = strings.ReplaceAll(strings.ReplaceAll(got, " ", ""), "\n", "") + want := strings.ReplaceAll(strings.ReplaceAll(tc.want, " ", ""), "\n", "") + if tc.wantSubstring { - if !bytes.Contains([]byte(result.Result), []byte(tc.want)) { - t.Fatalf("unexpected result: got %q, want substring %q", result.Result, tc.want) + if !strings.Contains(got, want) { + t.Fatalf("unexpected result: got %q, want substring %q", got, want) } return } - // The result is a JSON-encoded string, so we need to unmarshal it twice. - var tempString string - if err := json.Unmarshal([]byte(result.Result), &tempString); err != nil { - t.Fatalf("failed to unmarshal result string: %v", err) - } - - var got, want map[string]any - if err := json.Unmarshal([]byte(tempString), &got); err != nil { - t.Fatalf("failed to unmarshal result: %v", err) - } - if err := json.Unmarshal([]byte(tc.want), &want); err != nil { - t.Fatalf("failed to unmarshal want: %v", err) - } - - if !reflect.DeepEqual(got, want) { - t.Fatalf("unexpected result: got %+v, want %+v", got, want) + if got != want { + t.Fatalf("unexpected result: \ngot: %s\nwant: %s", got, want) } }) } diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index 6059c190a8..30307296f1 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -175,7 +175,7 @@ func TestBigQueryToolEndpoints(t *testing.T) { ddlWant := `"Query executed successfully and returned no content."` dataInsightsWant := `(?s)Schema Resolved.*Retrieval Query.*SQL Generated.*Answer` // Partial message; the full error message is too long. - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"query validation failed: failed to insert dry run job: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing GCP request: failed to insert dry run job: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"f0_\":1}"}]}}` createColArray := `["id INT64", "name STRING", "age INT64"]` selectEmptyWant := `"The query returned 0 rows."` @@ -954,7 +954,8 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), - isErr: true, + want: `{"error":"parameter \"sql\" is required"}`, + isErr: false, }, { name: "invoke my-exec-sql-tool", @@ -1009,6 +1010,7 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), + want: `{"error":"parameter \"sql\" is required"}`, isErr: true, }, { @@ -1161,12 +1163,11 @@ func runBigQueryWriteModeBlockedTest(t *testing.T, tableNameParam, datasetName s name string sql string wantStatusCode int - wantInError string wantResult string }{ - {"SELECT statement should succeed", fmt.Sprintf("SELECT id, name FROM %s WHERE id = 1", tableNameParam), http.StatusOK, "", `[{"id":1,"name":"Alice"}]`}, - {"INSERT statement should fail", fmt.Sprintf("INSERT INTO %s (id, name) VALUES (10, 'test')", tableNameParam), http.StatusBadRequest, "write mode is 'blocked', only SELECT statements are allowed", ""}, - {"CREATE TABLE statement should fail", fmt.Sprintf("CREATE TABLE %s.new_table (x INT64)", datasetName), http.StatusBadRequest, "write mode is 'blocked', only SELECT statements are allowed", ""}, + {"SELECT statement should succeed", fmt.Sprintf("SELECT id, name FROM %s WHERE id = 1", tableNameParam), http.StatusOK, `[{"id":1,"name":"Alice"}]`}, + {"INSERT statement should fail", fmt.Sprintf("INSERT INTO %s (id, name) VALUES (10, 'test')", tableNameParam), http.StatusOK, "{\"error\":\"write mode is 'blocked', only SELECT statements are allowed\"}"}, + {"CREATE TABLE statement should fail", fmt.Sprintf("CREATE TABLE %s.new_table (x INT64)", datasetName), http.StatusOK, "{\"error\":\"write mode is 'blocked', only SELECT statements are allowed\"}"}, } for _, tc := range testCases { @@ -1180,15 +1181,6 @@ func runBigQueryWriteModeBlockedTest(t *testing.T, tableNameParam, datasetName s t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) } - if tc.wantInError != "" { - errStr, ok := result["error"].(string) - if !ok { - t.Fatalf("expected 'error' field in response, got %v", result) - } - if !strings.Contains(errStr, tc.wantInError) { - t.Fatalf("expected error message to contain %q, but got %q", tc.wantInError, errStr) - } - } if tc.wantResult != "" { resStr, ok := result["result"].(string) if !ok { @@ -1215,9 +1207,9 @@ func runBigQueryWriteModeProtectedTest(t *testing.T, permanentDatasetName string name: "CREATE TABLE to permanent dataset should fail", toolName: "my-exec-sql-tool", requestBody: fmt.Sprintf(`{"sql": "CREATE TABLE %s.new_table (x INT64)"}`, permanentDatasetName), - wantStatusCode: http.StatusBadRequest, - wantInError: "protected write mode only supports SELECT statements, or write operations in the anonymous dataset", - wantResult: "", + wantStatusCode: http.StatusOK, + wantInError: "", + wantResult: "protected write mode only supports SELECT statements, or write operations in the anonymous dataset", }, { name: "CREATE TEMP TABLE should succeed", @@ -1709,7 +1701,8 @@ func runBigQueryDataTypeTests(t *testing.T) { api: "http://127.0.0.1:5000/api/tool/my-scalar-datatype-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"int_val": 123}`)), - isErr: true, + want: `{"error":"parameter \"string_val\" is required"}`, + isErr: false, }, { name: "invoke my-array-datatype-tool", @@ -2578,7 +2571,7 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed { name: "invoke on disallowed dataset", dataset: disallowedDatasetName, - wantStatusCode: http.StatusBadRequest, // Or the specific error code returned + wantStatusCode: http.StatusOK, wantInError: fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName), }, } @@ -2652,7 +2645,7 @@ func runGetDatasetInfoWithRestriction(t *testing.T, allowedDatasetName, disallow { name: "invoke on disallowed dataset", dataset: disallowedDatasetName, - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, wantInError: fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName), }, } @@ -2704,8 +2697,7 @@ func runGetTableInfoWithRestriction(t *testing.T, allowedDatasetName, disallowed name: "invoke on disallowed table", dataset: disallowedDatasetName, table: disallowedTableName, - wantStatusCode: http.StatusBadRequest, - wantInError: fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName), + wantStatusCode: http.StatusOK, }, } @@ -2759,7 +2751,7 @@ func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowed { name: "invoke on disallowed table", sql: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list", strings.Join( strings.Split(strings.Trim(disallowedTableFullName, "`"), ".")[0:2], @@ -2768,31 +2760,31 @@ func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowed { name: "disallowed create schema", sql: "CREATE SCHEMA another_dataset", - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, wantInError: "dataset-level operations like 'CREATE_SCHEMA' are not allowed", }, { name: "disallowed alter schema", sql: fmt.Sprintf("ALTER SCHEMA %s SET OPTIONS(description='new one')", allowedDatasetID), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, wantInError: "dataset-level operations like 'ALTER_SCHEMA' are not allowed", }, { name: "disallowed create function", sql: fmt.Sprintf("CREATE FUNCTION %s.my_func() RETURNS INT64 AS (1)", allowedDatasetID), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, wantInError: "creating stored routines ('CREATE_FUNCTION') is not allowed", }, { name: "disallowed create procedure", sql: fmt.Sprintf("CREATE PROCEDURE %s.my_proc() BEGIN SELECT 1; END", allowedDatasetID), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, wantInError: "unanalyzable statements like 'CREATE PROCEDURE' are not allowed", }, { name: "disallowed execute immediate", sql: "EXECUTE IMMEDIATE 'SELECT 1'", - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, wantInError: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place", }, } @@ -2846,7 +2838,7 @@ func runConversationalAnalyticsWithRestriction(t *testing.T, allowedDatasetName, { name: "invoke with disallowed table", tableRefs: disallowedTableRefsJSON, - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, wantInError: fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", BigqueryProject, disallowedDatasetName, disallowedTableName), }, } @@ -3030,12 +3022,24 @@ func runBigQuerySearchCatalogToolInvokeTest(t *testing.T, datasetName string, ta } t.Fatalf("expected 'result' field to be a string, got %T", result["result"]) } + + var errorCheck map[string]any + if err := json.Unmarshal([]byte(resultStr), &errorCheck); err == nil { + if _, hasError := errorCheck["error"]; hasError { + if tc.isErr { + return + } + t.Fatalf("unexpected error object in result: %s", resultStr) + } + } + if tc.isErr && (resultStr == "" || resultStr == "[]") { return } - var entries []interface{} + + var entries []any if err := json.Unmarshal([]byte(resultStr), &entries); err != nil { - t.Fatalf("error unmarshalling result string: %v", err) + t.Fatalf("error unmarshalling result string: %v. Raw string: %s", err, resultStr) } if !tc.isErr { @@ -3083,7 +3087,7 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa { name: "invoke with disallowed table name", historyData: disallowedTableUnquoted, - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, wantInError: fmt.Sprintf("access to dataset '%s' (from table '%s') is not allowed", disallowedDatasetFQN, disallowedTableUnquoted), }, { @@ -3095,7 +3099,7 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa { name: "invoke with query on disallowed table", historyData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, wantInError: fmt.Sprintf("query in history_data accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), }, } @@ -3174,8 +3178,8 @@ func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, d { name: "invoke with disallowed table name", inputData: disallowedTableUnquoted, - wantStatusCode: http.StatusBadRequest, - wantInError: fmt.Sprintf("access to dataset '%s' (from table '%s') is not allowed", disallowedDatasetFQN, disallowedTableUnquoted), + wantStatusCode: http.StatusOK, + wantInResult: fmt.Sprintf("access to dataset '%s' (from table '%s') is not allowed", disallowedDatasetFQN, disallowedTableUnquoted), }, { name: "invoke with query on allowed table", @@ -3186,8 +3190,8 @@ func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, d { name: "invoke with query on disallowed table", inputData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), - wantStatusCode: http.StatusBadRequest, - wantInError: fmt.Sprintf("query in input_data accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), + wantStatusCode: http.StatusOK, + wantInResult: fmt.Sprintf("query in input_data accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), }, } diff --git a/tests/bigtable/bigtable_integration_test.go b/tests/bigtable/bigtable_integration_test.go index 49a7ca69c3..d2ee4cad09 100644 --- a/tests/bigtable/bigtable_integration_test.go +++ b/tests/bigtable/bigtable_integration_test.go @@ -120,7 +120,7 @@ func TestBigtableToolEndpoints(t *testing.T) { // Actual test parameters are set in https://github.com/googleapis/genai-toolbox/blob/52b09a67cb40ac0c5f461598b4673136699a3089/tests/tool_test.go#L250 select1Want := "[{\"$col1\":1}]" myToolById4Want := `[{"id":4,"name":""}]` - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to prepare statement: rpc error: code = InvalidArgument desc = Syntax error: Unexpected identifier \"SELEC\" [at 1:1]"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing GCP request: unable to prepare statement: rpc error: code = InvalidArgument desc = Syntax error: Unexpected identifier \"SELEC\" [at 1:1]"}],"isError":true}}` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"$col1\":1}"}]}}` nameFieldArray := `["CAST(cf['name'] AS string) as name"]` nameColFilter := "CAST(cf['name'] AS string)" diff --git a/tests/cassandra/cassandra_integration_test.go b/tests/cassandra/cassandra_integration_test.go index e1faac4554..2f833999cc 100644 --- a/tests/cassandra/cassandra_integration_test.go +++ b/tests/cassandra/cassandra_integration_test.go @@ -271,7 +271,7 @@ func getCassandraWants() (string, string, string, string, string, string) { selectIdNameWant := "[{\"id\":3,\"name\":\"Alice\"}]" selectIdNullWant := "[{\"id\":4,\"name\":\"\"}]" selectArrayParamWant := "[{\"id\":1,\"name\":\"Sid\"},{\"id\":3,\"name\":\"Alice\"}]" - mcpMyFailToolWant := "{\"jsonrpc\":\"2.0\",\"id\":\"invoke-fail-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"unable to parse rows: line 1:0 no viable alternative at input 'SELEC' ([SELEC]...)\"}],\"isError\":true}}" + mcpMyFailToolWant := "{\"jsonrpc\":\"2.0\",\"id\":\"invoke-fail-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"error processing request: unable to parse rows: line 1:0 no viable alternative at input 'SELEC' ([SELEC]...)\"}],\"isError\":true}}" mcpMyToolIdWant := "{\"jsonrpc\":\"2.0\",\"id\":\"my-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"[{\\\"id\\\":3,\\\"name\\\":\\\"Alice\\\"}]\"}]}}" return selectIdNameWant, selectIdNullWant, selectArrayParamWant, mcpMyFailToolWant, "nil", mcpMyToolIdWant } diff --git a/tests/clickhouse/clickhouse_integration_test.go b/tests/clickhouse/clickhouse_integration_test.go index 911bfdde11..6b0ae8d961 100644 --- a/tests/clickhouse/clickhouse_integration_test.go +++ b/tests/clickhouse/clickhouse_integration_test.go @@ -339,7 +339,7 @@ func TestClickHouseBasicConnection(t *testing.T) { func getClickHouseWants() (string, string, string, string, string) { select1Want := "[{\"1\":1}]" mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: sendQuery: [HTTP 400] response body: \"Code: 62. DB::Exception: Syntax error: failed at position 1 (SELEC): SELEC 1;. Expected one of: Query, Query with output, EXPLAIN, EXPLAIN, SELECT query, possibly with UNION, list of union elements, SELECT query, subquery, possibly with UNION, SELECT subquery, SELECT query, WITH, FROM, SELECT, SHOW CREATE QUOTA query, SHOW CREATE, SHOW [FULL] [TEMPORARY] TABLES|DATABASES|CLUSTERS|CLUSTER|MERGES 'name' [[NOT] [I]LIKE 'str'] [LIMIT expr], SHOW, SHOW COLUMNS query, SHOW ENGINES query, SHOW ENGINES, SHOW FUNCTIONS query, SHOW FUNCTIONS, SHOW INDEXES query, SHOW SETTING query, SHOW SETTING, EXISTS or SHOW CREATE query, EXISTS, DESCRIBE FILESYSTEM CACHE query, DESCRIBE, DESC, DESCRIBE query, SHOW PROCESSLIST query, SHOW PROCESSLIST, CREATE TABLE or ATTACH TABLE query, CREATE, ATTACH, REPLACE, CREATE DATABASE query, CREATE VIEW query, CREATE DICTIONARY, CREATE LIVE VIEW query, CREATE WINDOW VIEW query, ALTER query, ALTER TABLE, ALTER TEMPORARY TABLE, ALTER DATABASE, RENAME query, RENAME DATABASE, RENAME TABLE, EXCHANGE TABLES, RENAME DICTIONARY, EXCHANGE DICTIONARIES, RENAME, DROP query, DROP, DETACH, TRUNCATE, UNDROP query, UNDROP, CHECK ALL TABLES, CHECK TABLE, KILL QUERY query, KILL, OPTIMIZE query, OPTIMIZE TABLE, WATCH query, WATCH, SHOW ACCESS query, SHOW ACCESS, ShowAccessEntitiesQuery, SHOW GRANTS query, SHOW GRANTS, SHOW PRIVILEGES query, SHOW PRIVILEGES, BACKUP or RESTORE query, BACKUP, RESTORE, INSERT query, INSERT INTO, USE query, USE, SET ROLE or SET DEFAULT ROLE query, SET ROLE DEFAULT, SET ROLE, SET DEFAULT ROLE, SET query, SET, SYSTEM query, SYSTEM, CREATE USER or ALTER USER query, ALTER USER, CREATE USER, CREATE ROLE or ALTER ROLE query, ALTER ROLE, CREATE ROLE, CREATE QUOTA or ALTER QUOTA query, ALTER QUOTA, CREATE QUOTA, CREATE ROW POLICY or ALTER ROW POLICY query, ALTER POLICY, ALTER ROW POLICY, CREATE POLICY, CREATE ROW POLICY, CREATE SETTINGS PROFILE or ALTER SETTINGS PROFILE query, ALTER SETTINGS PROFILE, ALTER PROFILE, CREATE SETTINGS PROFILE, CREATE PROFILE, CREATE FUNCTION query, DROP FUNCTION query, CREATE WORKLOAD query, DROP WORKLOAD query, CREATE RESOURCE query, DROP RESOURCE query, CREATE NAMED COLLECTION, DROP NAMED COLLECTION query, Alter NAMED COLLECTION query, ALTER, CREATE INDEX query, DROP INDEX query, DROP access entity query, MOVE access entity query, MOVE, GRANT or REVOKE query, REVOKE, GRANT, CHECK GRANT, CHECK GRANT, EXTERNAL DDL query, EXTERNAL DDL FROM, TCL query, BEGIN TRANSACTION, START TRANSACTION, COMMIT, ROLLBACK, SET TRANSACTION SNAPSHOT, Delete query, DELETE, Update query, UPDATE. (SYNTAX_ERROR) (version 25.7.5.34 (official build))\n\""}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: sendQuery: [HTTP 400] response body: \"Code: 62. DB::Exception: Syntax error: failed at position 1 (SELEC): SELEC 1;. Expected one of: Query, Query with output, EXPLAIN, EXPLAIN, SELECT query, possibly with UNION, list of union elements, SELECT query, subquery, possibly with UNION, SELECT subquery, SELECT query, WITH, FROM, SELECT, SHOW CREATE QUOTA query, SHOW CREATE, SHOW [FULL] [TEMPORARY] TABLES|DATABASES|CLUSTERS|CLUSTER|MERGES 'name' [[NOT] [I]LIKE 'str'] [LIMIT expr], SHOW, SHOW COLUMNS query, SHOW ENGINES query, SHOW ENGINES, SHOW FUNCTIONS query, SHOW FUNCTIONS, SHOW INDEXES query, SHOW SETTING query, SHOW SETTING, EXISTS or SHOW CREATE query, EXISTS, DESCRIBE FILESYSTEM CACHE query, DESCRIBE, DESC, DESCRIBE query, SHOW PROCESSLIST query, SHOW PROCESSLIST, CREATE TABLE or ATTACH TABLE query, CREATE, ATTACH, REPLACE, CREATE DATABASE query, CREATE VIEW query, CREATE DICTIONARY, CREATE LIVE VIEW query, CREATE WINDOW VIEW query, ALTER query, ALTER TABLE, ALTER TEMPORARY TABLE, ALTER DATABASE, RENAME query, RENAME DATABASE, RENAME TABLE, EXCHANGE TABLES, RENAME DICTIONARY, EXCHANGE DICTIONARIES, RENAME, DROP query, DROP, DETACH, TRUNCATE, UNDROP query, UNDROP, CHECK ALL TABLES, CHECK TABLE, KILL QUERY query, KILL, OPTIMIZE query, OPTIMIZE TABLE, WATCH query, WATCH, SHOW ACCESS query, SHOW ACCESS, ShowAccessEntitiesQuery, SHOW GRANTS query, SHOW GRANTS, SHOW PRIVILEGES query, SHOW PRIVILEGES, BACKUP or RESTORE query, BACKUP, RESTORE, INSERT query, INSERT INTO, USE query, USE, SET ROLE or SET DEFAULT ROLE query, SET ROLE DEFAULT, SET ROLE, SET DEFAULT ROLE, SET query, SET, SYSTEM query, SYSTEM, CREATE USER or ALTER USER query, ALTER USER, CREATE USER, CREATE ROLE or ALTER ROLE query, ALTER ROLE, CREATE ROLE, CREATE QUOTA or ALTER QUOTA query, ALTER QUOTA, CREATE QUOTA, CREATE ROW POLICY or ALTER ROW POLICY query, ALTER POLICY, ALTER ROW POLICY, CREATE POLICY, CREATE ROW POLICY, CREATE SETTINGS PROFILE or ALTER SETTINGS PROFILE query, ALTER SETTINGS PROFILE, ALTER PROFILE, CREATE SETTINGS PROFILE, CREATE PROFILE, CREATE FUNCTION query, DROP FUNCTION query, CREATE WORKLOAD query, DROP WORKLOAD query, CREATE RESOURCE query, DROP RESOURCE query, CREATE NAMED COLLECTION, DROP NAMED COLLECTION query, Alter NAMED COLLECTION query, ALTER, CREATE INDEX query, DROP INDEX query, DROP access entity query, MOVE access entity query, MOVE, GRANT or REVOKE query, REVOKE, GRANT, CHECK GRANT, CHECK GRANT, EXTERNAL DDL query, EXTERNAL DDL FROM, TCL query, BEGIN TRANSACTION, START TRANSACTION, COMMIT, ROLLBACK, SET TRANSACTION SNAPSHOT, Delete query, DELETE, Update query, UPDATE. (SYNTAX_ERROR) (version 25.7.5.34 (official build))\n\""}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id UInt32, name String) ENGINE = Memory"` nullWant := `[{"id":4,"name":""}]` return select1Want, mcpSelect1Want, mcpMyFailToolWant, createTableStatement, nullWant @@ -548,6 +548,7 @@ func TestClickHouseExecuteSQLTool(t *testing.T) { sql string resultSliceLen int isErr bool + isAgentErr bool }{ { name: "CreateTable", @@ -570,15 +571,15 @@ func TestClickHouseExecuteSQLTool(t *testing.T) { resultSliceLen: 0, }, { - name: "MissingSQL", - sql: "", - isErr: true, + name: "MissingSQL", + sql: "", + isAgentErr: true, }, { - name: "SQLInjectionAttempt", - sql: "SELECT 1; DROP TABLE system.users; SELECT 2", - isErr: true, + name: "SQLInjectionAttempt", + sql: "SELECT 1; DROP TABLE system.users; SELECT 2", + isAgentErr: true, }, } for _, tc := range tcs { @@ -595,6 +596,9 @@ func TestClickHouseExecuteSQLTool(t *testing.T) { if tc.isErr { t.Fatalf("expecting an error from server") } + if tc.isAgentErr { + return + } var body map[string]interface{} err := json.Unmarshal(respBody, &body) @@ -1119,16 +1123,16 @@ func TestClickHouseListTablesTool(t *testing.T) { t.Run("ListTablesWithMissingDatabase", func(t *testing.T) { api := "http://127.0.0.1:5000/api/tool/test-list-tables/invoke" resp, _ := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)), nil) - if resp.StatusCode == http.StatusOK { - t.Error("Expected error for missing database parameter, but got 200 OK") + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected 200 OK for missing database parameter, but got %d", resp.StatusCode) } }) t.Run("ListTablesWithInvalidSource", func(t *testing.T) { api := "http://127.0.0.1:5000/api/tool/test-invalid-source/invoke" resp, _ := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)), nil) - if resp.StatusCode == http.StatusOK { - t.Error("Expected error for non-existent source, but got 200 OK") + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected 200 OK for non-existent source, but got %d", resp.StatusCode) } }) diff --git a/tests/cloudhealthcare/cloud_healthcare_integration_test.go b/tests/cloudhealthcare/cloud_healthcare_integration_test.go index 72dac07928..4ffeee85d9 100644 --- a/tests/cloudhealthcare/cloud_healthcare_integration_test.go +++ b/tests/cloudhealthcare/cloud_healthcare_integration_test.go @@ -717,8 +717,8 @@ func runGetDatasetToolInvokeTest(t *testing.T, want string) { t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -915,8 +915,8 @@ func runListDICOMStoresToolInvokeTest(t *testing.T, want string) { t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -1021,8 +1021,8 @@ func runGetFHIRStoreToolInvokeTest(t *testing.T, fhirStoreID, want string) { t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -1127,8 +1127,8 @@ func runGetFHIRStoreMetricsToolInvokeTest(t *testing.T, fhirStoreID, want string t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -1240,8 +1240,8 @@ func runGetFHIRResourceToolInvokeTest(t *testing.T, storeID, resType, resID, wan t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -1394,8 +1394,8 @@ func runFHIRPatientSearchToolInvokeTest(t *testing.T, fhirStoreID string, patien t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -1531,8 +1531,8 @@ func runFHIRPatientEverythingToolInvokeTest(t *testing.T, fhirStoreID, patientID t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -1637,8 +1637,8 @@ func runFHIRFetchPageToolInvokeTest(t *testing.T, pageURL, want string) { t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -1710,6 +1710,9 @@ func runTest(t *testing.T, api string, requestHeader map[string]string, requestB got, ok := body["result"].(string) if !ok { + if errMsg, ok := body["error"].(string); ok { + return errMsg, http.StatusOK + } t.Fatalf("unable to find result in response body") } return got, http.StatusOK @@ -1837,8 +1840,8 @@ func runGetDICOMStoreToolInvokeTest(t *testing.T, dicomStoreID, want string) { t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -1943,8 +1946,8 @@ func runGetDICOMStoreMetricsToolInvokeTest(t *testing.T, dicomStoreID, want stri t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -2065,8 +2068,8 @@ func runSearchDICOMStudiesToolInvokeTest(t *testing.T, dicomStoreID string) { t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -2187,8 +2190,8 @@ func runSearchDICOMSeriesToolInvokeTest(t *testing.T, dicomStoreID string) { t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -2309,8 +2312,8 @@ func runSearchDICOMInstancesToolInvokeTest(t *testing.T, dicomStoreID string) { t.Run(tc.name, func(t *testing.T) { got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } @@ -2422,10 +2425,10 @@ func runRetrieveRenderedDICOMInstanceToolInvokeTest(t *testing.T, dicomStoreID s } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - _, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) + got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) if tc.isErr { - if status == http.StatusOK { - t.Errorf("expected error but got success") + if status == http.StatusOK && !strings.Contains(got, "error") { + t.Errorf("expected error but got success: %s", got) } return } diff --git a/tests/cloudloggingadmin/cloud_logging_admin_integration_test.go b/tests/cloudloggingadmin/cloud_logging_admin_integration_test.go index 92cbb8fe32..68c3227621 100644 --- a/tests/cloudloggingadmin/cloud_logging_admin_integration_test.go +++ b/tests/cloudloggingadmin/cloud_logging_admin_integration_test.go @@ -332,8 +332,8 @@ func runQueryLogsErrorTest(t *testing.T) { t.Run("query-logs-error", func(t *testing.T) { requestBody := `{"filter": "INVALID_FILTER_SYNTAX :::", "limit": 10}` resp, _ := tests.RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/api/tool/query-logs/invoke", bytes.NewBuffer([]byte(requestBody)), nil) - if resp.StatusCode == 200 { - t.Errorf("expected error status code, got 200 OK") + if resp.StatusCode != 200 { + t.Errorf("expected 200 OK") } }) } diff --git a/tests/cloudsql/cloud_sql_clone_instance_test.go b/tests/cloudsql/cloud_sql_clone_instance_test.go index 024c24d153..ac504b2a98 100644 --- a/tests/cloudsql/cloud_sql_clone_instance_test.go +++ b/tests/cloudsql/cloud_sql_clone_instance_test.go @@ -169,11 +169,10 @@ func TestCloneInstanceToolEndpoints(t *testing.T) { want: `{"name":"op2","status":"PENDING"}`, }, { - name: "missing destination instance name", - toolName: "clone-instance", - body: `{"project": "p1", "sourceInstanceName": "source-instance"}`, - expectError: true, - errorStatus: http.StatusBadRequest, + name: "missing destination instance name", + toolName: "clone-instance", + body: `{"project": "p1", "sourceInstanceName": "source-instance"}`, + want: `{"error":"parameter \"destinationInstanceName\" is required"}`, }, } diff --git a/tests/cloudsql/cloud_sql_create_backup_test.go b/tests/cloudsql/cloud_sql_create_backup_test.go index daebe9a732..7155e5e964 100644 --- a/tests/cloudsql/cloud_sql_create_backup_test.go +++ b/tests/cloudsql/cloud_sql_create_backup_test.go @@ -158,11 +158,10 @@ func TestCreateBackupToolEndpoints(t *testing.T) { want: `{"name":"op1","status":"PENDING"}`, }, { - name: "missing instance name", - toolName: "create-backup", - body: `{"project": "p1", "escription": "invalid"}`, - expectError: true, - errorStatus: http.StatusBadRequest, + name: "missing instance name", + toolName: "create-backup", + body: `{"project": "p1", "escription": "invalid"}`, + want: `{"error":"parameter \"instance\" is required"}`, }, } diff --git a/tests/cloudsql/cloud_sql_create_database_test.go b/tests/cloudsql/cloud_sql_create_database_test.go index c68d7dfb12..a9ef3ff2fb 100644 --- a/tests/cloudsql/cloud_sql_create_database_test.go +++ b/tests/cloudsql/cloud_sql_create_database_test.go @@ -155,11 +155,10 @@ func TestCreateDatabaseToolEndpoints(t *testing.T) { want: `{"name":"op1","status":"PENDING"}`, }, { - name: "missing name", - toolName: "create-database", - body: `{"project": "p1", "instance": "i1"}`, - expectError: true, - errorStatus: http.StatusBadRequest, + name: "missing name", + toolName: "create-database", + body: `{"project": "p1", "instance": "i1"}`, + want: `{"error":"parameter \"name\" is required"}`, }, } diff --git a/tests/cloudsql/cloud_sql_create_users_test.go b/tests/cloudsql/cloud_sql_create_users_test.go index 77978c4506..e4b8bd0b2c 100644 --- a/tests/cloudsql/cloud_sql_create_users_test.go +++ b/tests/cloudsql/cloud_sql_create_users_test.go @@ -167,11 +167,10 @@ func TestCreateUsersToolEndpoints(t *testing.T) { want: `{"name":"op2","status":"PENDING"}`, }, { - name: "missing password for built-in user", - toolName: "create-user", - body: `{"project": "p1", "instance": "i1", "name": "test-user", "iamUser": false}`, - expectError: true, - errorStatus: http.StatusBadRequest, + name: "missing password for built-in user", + toolName: "create-user", + body: `{"project": "p1", "instance": "i1", "name": "test-user", "iamUser": false}`, + want: `{"error":"missing 'password' parameter for non-IAM user"}`, }, } diff --git a/tests/cloudsql/cloud_sql_list_databases_test.go b/tests/cloudsql/cloud_sql_list_databases_test.go index 34719d2b03..9d49f45d25 100644 --- a/tests/cloudsql/cloud_sql_list_databases_test.go +++ b/tests/cloudsql/cloud_sql_list_databases_test.go @@ -138,11 +138,10 @@ func TestListDatabasesToolEndpoints(t *testing.T) { want: `[{"name":"db1","charset":"utf8","collation":"utf8_general_ci"},{"name":"db2","charset":"utf8mb4","collation":"utf8mb4_unicode_ci"}]`, }, { - name: "missing instance", - toolName: "list-databases", - body: `{"project": "p1"}`, - expectError: true, - errorStatus: http.StatusBadRequest, + name: "missing instance", + toolName: "list-databases", + body: `{"project": "p1"}`, + want: `{"error":"parameter \"instance\" is required"}`, }, } @@ -181,12 +180,26 @@ func TestListDatabasesToolEndpoints(t *testing.T) { t.Fatalf("failed to decode response: %v", err) } + if strings.Contains(result.Result, `"error":`) { + var gotMap, wantMap map[string]any + if err := json.Unmarshal([]byte(result.Result), &gotMap); err != nil { + t.Fatalf("failed to unmarshal result error object: %v", err) + } + if err := json.Unmarshal([]byte(tc.want), &wantMap); err != nil { + t.Fatalf("failed to unmarshal want error object: %v", err) + } + if !reflect.DeepEqual(gotMap, wantMap) { + t.Fatalf("unexpected error result: got %+v, want %+v", gotMap, wantMap) + } + return + } + var got, want []map[string]any if err := json.Unmarshal([]byte(result.Result), &got); err != nil { - t.Fatalf("failed to unmarshal result: %v", err) + t.Fatalf("failed to unmarshal result array: %v. Result was: %s", err, result.Result) } if err := json.Unmarshal([]byte(tc.want), &want); err != nil { - t.Fatalf("failed to unmarshal want: %v", err) + t.Fatalf("failed to unmarshal want array: %v", err) } if !reflect.DeepEqual(got, want) { diff --git a/tests/cloudsql/cloud_sql_restore_backup_test.go b/tests/cloudsql/cloud_sql_restore_backup_test.go index 970ad16164..47a0411945 100644 --- a/tests/cloudsql/cloud_sql_restore_backup_test.go +++ b/tests/cloudsql/cloud_sql_restore_backup_test.go @@ -23,7 +23,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "reflect" "regexp" "strings" "testing" @@ -95,7 +94,11 @@ func (h *masterRestoreBackupHandler) ServeHTTP(w http.ResponseWriter, r *http.Re response = map[string]any{"name": "op1", "status": "PENDING"} statusCode = http.StatusOK default: - http.Error(w, fmt.Sprintf("unhandled restore request body: %v", body), http.StatusInternalServerError) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": `oaraneter "backup_id" is required`, + }) return } @@ -178,25 +181,22 @@ func TestRestoreBackupToolEndpoints(t *testing.T) { want: `{"name":"op1","status":"PENDING"}`, }, { - name: "missing source instance info for standard backup", - toolName: "restore-backup", - body: `{"target_project": "p1", "target_instance": "instance-project-level", "backup_id": "12345"}`, - expectError: true, - errorStatus: http.StatusBadRequest, + name: "missing source instance info for standard backup", + toolName: "restore-backup", + body: `{"target_project": "p1", "target_instance": "instance-project-level", "backup_id": "12345"}`, + want: `{"error":"error processing GCP request: source project and instance are required when restoring via backup ID"}`, }, { - name: "missing backup identifier", - toolName: "restore-backup", - body: `{"target_project": "p1", "target_instance": "instance-project-level"}`, - expectError: true, - errorStatus: http.StatusBadRequest, + name: "missing backup identifier", + toolName: "restore-backup", + body: `{"target_project": "p1", "target_instance": "instance-project-level"}`, + want: `{"error":"parameter \"backup_id\" is required"}`, }, { - name: "missing target instance info", - toolName: "restore-backup", - body: `{"backup_id": "12345"}`, - expectError: true, - errorStatus: http.StatusBadRequest, + name: "missing target instance info", + toolName: "restore-backup", + body: `{"backup_id": "12345"}`, + want: `{"error":"parameter \"target_project\" is required"}`, }, } @@ -232,19 +232,14 @@ func TestRestoreBackupToolEndpoints(t *testing.T) { Result string `json:"result"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - t.Fatalf("failed to decode response: %v", err) + t.Fatalf("failed to decode response envelope: %v", err) } - var got, want map[string]any - if err := json.Unmarshal([]byte(result.Result), &got); err != nil { - t.Fatalf("failed to unmarshal result: %v", err) - } - if err := json.Unmarshal([]byte(tc.want), &want); err != nil { - t.Fatalf("failed to unmarshal want: %v", err) - } + got := strings.TrimSpace(result.Result) + want := strings.TrimSpace(tc.want) - if !reflect.DeepEqual(got, want) { - t.Fatalf("unexpected result: got %+v, want %+v", got, want) + if got != want { + t.Fatalf("unexpected result string:\n got: %s\nwant: %s", got, want) } }) } diff --git a/tests/cloudsql/cloudsql_wait_for_operation_test.go b/tests/cloudsql/cloudsql_wait_for_operation_test.go index 33c48077f2..e8225f8380 100644 --- a/tests/cloudsql/cloudsql_wait_for_operation_test.go +++ b/tests/cloudsql/cloudsql_wait_for_operation_test.go @@ -206,10 +206,10 @@ func TestCloudSQLWaitToolEndpoints(t *testing.T) { wantSubstring: true, }, { - name: "failed operation", - toolName: "wait-for-op2", - body: `{"project": "p1", "operation": "op2"}`, - expectError: true, + name: "failed operation - agent error", + toolName: "wait-for-op2", + body: `{"project": "p1", "operation": "op2"}`, + wantSubstring: true, }, { name: "non-database create operation", diff --git a/tests/cloudsqlmssql/cloud_sql_mssql_create_instance_integration_test.go b/tests/cloudsqlmssql/cloud_sql_mssql_create_instance_integration_test.go index f468869656..4ae8c0a7e9 100644 --- a/tests/cloudsqlmssql/cloud_sql_mssql_create_instance_integration_test.go +++ b/tests/cloudsqlmssql/cloud_sql_mssql_create_instance_integration_test.go @@ -198,11 +198,10 @@ func TestCreateInstanceToolEndpoints(t *testing.T) { want: `{"name":"op2","status":"RUNNING"}`, }, { - name: "missing required parameter", - toolName: "create-instance-prod", - body: `{"name": "instance1"}`, - expectError: true, - errorStatus: http.StatusBadRequest, + name: "missing required parameter", + toolName: "create-instance-prod", + body: `{"name": "instance1"}`, + want: `{"error":"parameter \"project\" is required"}`, }, } diff --git a/tests/cloudsqlmysql/cloud_sql_mysql_create_instance_integration_test.go b/tests/cloudsqlmysql/cloud_sql_mysql_create_instance_integration_test.go index 4af92f7648..45975103aa 100644 --- a/tests/cloudsqlmysql/cloud_sql_mysql_create_instance_integration_test.go +++ b/tests/cloudsqlmysql/cloud_sql_mysql_create_instance_integration_test.go @@ -199,11 +199,10 @@ func TestCreateInstanceToolEndpoints(t *testing.T) { want: `{"name":"op2","status":"RUNNING"}`, }, { - name: "missing required parameter", - toolName: "create-instance-prod", - body: `{"name": "instance1"}`, - expectError: true, - errorStatus: http.StatusBadRequest, + name: "missing required parameter", + toolName: "create-instance-prod", + body: `{"name": "instance1"}`, + want: `{"error":"parameter \"project\" is required"}`, }, } diff --git a/tests/cloudsqlpg/cloud_sql_pg_create_instances_test.go b/tests/cloudsqlpg/cloud_sql_pg_create_instances_test.go index aaef8f5bcc..df7350801f 100644 --- a/tests/cloudsqlpg/cloud_sql_pg_create_instances_test.go +++ b/tests/cloudsqlpg/cloud_sql_pg_create_instances_test.go @@ -200,11 +200,10 @@ func TestCreateInstanceToolEndpoints(t *testing.T) { want: `{"name":"op2","status":"RUNNING"}`, }, { - name: "missing required parameter", - toolName: "create-instance-prod", - body: `{"name": "instance1"}`, - expectError: true, - errorStatus: http.StatusBadRequest, + name: "missing required parameter", + toolName: "create-instance-prod", + body: `{"name": "instance1"}`, + want: `{"error":"parameter \"project\" is required"}`, }, } diff --git a/tests/cloudsqlpg/cloud_sql_pg_upgrade_precheck_test.go b/tests/cloudsqlpg/cloud_sql_pg_upgrade_precheck_test.go index 881e4bee15..118680e3b1 100644 --- a/tests/cloudsqlpg/cloud_sql_pg_upgrade_precheck_test.go +++ b/tests/cloudsqlpg/cloud_sql_pg_upgrade_precheck_test.go @@ -276,25 +276,22 @@ func TestPreCheckToolEndpoints(t *testing.T) { name: "instance not found", toolName: "precheck-tool", body: `{"project": "p1", "instance": "instance-notfound", "targetDatabaseVersion": "POSTGRES_18"}`, + want: `{"error":"failed to access GCP resource: googleapi: got HTTP response code 403 with body: Not authorized to access instance\n"}`, expectError: true, - errorStatus: http.StatusBadRequest, - errorMsg: "Not authorized to access instance", + errorStatus: http.StatusInternalServerError, + errorMsg: "failed to access GCP resource: googleapi: got HTTP response code 403", }, { - name: "missing required parameter - project", - toolName: "precheck-tool", - body: `{"instance": "instance-ok", "targetDatabaseVersion": "POSTGRES_18"}`, - expectError: true, - errorStatus: http.StatusBadRequest, - errorMsg: "parameter \\\"project\\\" is required", + name: "missing required parameter - project", + toolName: "precheck-tool", + body: `{"instance": "instance-ok", "targetDatabaseVersion": "POSTGRES_18"}`, + want: `{"error":"parameter \"project\" is required"}`, }, { - name: "missing required parameter - instance", - toolName: "precheck-tool", - body: `{"project": "p1", "targetDatabaseVersion": "POSTGRES_18"}`, // Missing instance - expectError: true, - errorStatus: http.StatusBadRequest, - errorMsg: "parameter \\\"instance\\\" is required", + name: "missing required parameter - instance", + toolName: "precheck-tool", + body: `{"project": "p1", "targetDatabaseVersion": "POSTGRES_18"}`, // Missing instance + want: `{"error":"parameter \"instance\" is required"}`, }, { name: "missing parameter - targetDatabaseVersion", diff --git a/tests/common.go b/tests/common.go index d200d59dd6..480143dbce 100644 --- a/tests/common.go +++ b/tests/common.go @@ -557,7 +557,7 @@ func GetCockroachDBWants() (string, string, string, string) { // CockroachDB formats syntax errors differently than PostgreSQL: // - Uses lowercase for SQL keywords in error messages // - Uses format: 'at or near "token": syntax error' instead of 'syntax error at or near "TOKEN"' - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ERROR: at or near \"selec\": syntax error (SQLSTATE 42601)"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: ERROR: at or near \"selec\": syntax error (SQLSTATE 42601)"}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id INT PRIMARY KEY, name TEXT)"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"?column?\":1}"}]}}` return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want @@ -622,7 +622,7 @@ func GetMySQLTmplToolStatement() (string, string) { // GetPostgresWants return the expected wants for postgres func GetPostgresWants() (string, string, string, string) { select1Want := "[{\"?column?\":1}]" - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ERROR: syntax error at or near \"SELEC\" (SQLSTATE 42601)"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: ERROR: syntax error at or near \"SELEC\" (SQLSTATE 42601)"}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id SERIAL PRIMARY KEY, name TEXT)"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"?column?\":1}"}]}}` return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want @@ -631,7 +631,7 @@ func GetPostgresWants() (string, string, string, string) { // GetMSSQLWants return the expected wants for mssql func GetMSSQLWants() (string, string, string, string) { select1Want := "[{\"\":1}]" - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: mssql: Could not find stored procedure 'SELEC'."}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: mssql: Could not find stored procedure 'SELEC'."}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id INT IDENTITY(1,1) PRIMARY KEY, name NVARCHAR(MAX))"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"\":1}"}]}}` return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want @@ -640,7 +640,7 @@ func GetMSSQLWants() (string, string, string, string) { // GetMySQLWants return the expected wants for mysql func GetMySQLWants() (string, string, string, string) { select1Want := "[{\"1\":1}]" - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id SERIAL PRIMARY KEY, name TEXT)"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want diff --git a/tests/couchbase/couchbase_integration_test.go b/tests/couchbase/couchbase_integration_test.go index d78c71b82d..1d7bb6bfc4 100644 --- a/tests/couchbase/couchbase_integration_test.go +++ b/tests/couchbase/couchbase_integration_test.go @@ -137,7 +137,7 @@ func TestCouchbaseToolEndpoints(t *testing.T) { // Get configs for tests select1Want := "[{\"$1\":1}]" - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: parsing failure | {\"statement\":\"SELEC 1;\"` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: parsing failure | {\"statement\":\"SELEC 1;\"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"$1\":1}"}]}}` tmplSelectId1Want := "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]" selectAllWant := "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]" diff --git a/tests/dataform/dataform_integration_test.go b/tests/dataform/dataform_integration_test.go index d235998359..3be737fca0 100644 --- a/tests/dataform/dataform_integration_test.go +++ b/tests/dataform/dataform_integration_test.go @@ -109,13 +109,13 @@ func TestDataformCompileTool(t *testing.T) { { name: "missing parameter", reqBody: `{}`, - wantStatus: http.StatusBadRequest, - wantBody: `parameter \"project_dir\" is required`, + wantStatus: http.StatusOK, + wantBody: `error`, }, { name: "non-existent directory", reqBody: fmt.Sprintf(`{"project_dir":"%s"}`, nonExistentDir), - wantStatus: http.StatusBadRequest, + wantStatus: http.StatusOK, wantBody: "error executing dataform compile", }, } diff --git a/tests/dataplex/dataplex_integration_test.go b/tests/dataplex/dataplex_integration_test.go index 1dcd72aeb3..602c74ec2e 100644 --- a/tests/dataplex/dataplex_integration_test.go +++ b/tests/dataplex/dataplex_integration_test.go @@ -517,8 +517,11 @@ func runDataplexSearchEntriesToolInvokeTest(t *testing.T, tableName string, data t.Fatalf("expected entry to have key '%s', but it was not found in %v", tc.wantContentKey, entry) } } else { - if len(entries) != 0 { - t.Fatalf("expected 0 entries, but got %d", len(entries)) + isResultEmpty := resultStr == "" || resultStr == "[]" || resultStr == "null" + hasError := strings.Contains(resultStr, `"error":`) + + if !isResultEmpty && !hasError { + t.Fatalf("expected an empty result or error message, but got: %s", resultStr) } } }) @@ -584,7 +587,7 @@ func runDataplexLookupEntryToolInvokeTest(t *testing.T, tableName string, datase api: "http://127.0.0.1:5000/api/tool/my-dataplex-lookup-entry-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"name\":\"projects/%s/locations/us\", \"entry\":\"projects/%s/locations/us/entryGroups/@bigquery/entries/bigquery.googleapis.com/projects/%s/datasets/%s\"}", DataplexProject, DataplexProject, DataplexProject, "non-existent-dataset"))), - wantStatusCode: 400, + wantStatusCode: 200, expectResult: false, }, { @@ -602,7 +605,7 @@ func runDataplexLookupEntryToolInvokeTest(t *testing.T, tableName string, datase api: "http://127.0.0.1:5000/api/tool/my-dataplex-lookup-entry-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"name\":\"projects/%s/locations/us\", \"entry\":\"projects/%s/locations/us/entryGroups/@bigquery/entries/bigquery.googleapis.com/projects/%s/datasets/%s/tables/%s\", \"view\": %d}", DataplexProject, DataplexProject, DataplexProject, datasetName, tableName, 3))), - wantStatusCode: 400, + wantStatusCode: 200, expectResult: false, }, { @@ -643,42 +646,44 @@ func runDataplexLookupEntryToolInvokeTest(t *testing.T, tableName string, datase t.Fatalf("Error parsing response body: %v", err) } + resultStr, hasResult := result["result"].(string) + if tc.expectResult { - resultStr, ok := result["result"].(string) - if !ok { - t.Fatalf("Expected 'result' field to be a string on success, got %T", result["result"]) - } - if resultStr == "" || resultStr == "{}" || resultStr == "null" { - t.Fatal("Expected an entry, but got empty result") + if !hasResult || resultStr == "" || resultStr == "{}" || resultStr == "null" { + t.Fatalf("Expected a result, but got: %v", result) } var entry map[string]interface{} if err := json.Unmarshal([]byte(resultStr), &entry); err != nil { - t.Fatalf("Error unmarshalling result string into entry map: %v", err) + t.Fatalf("Error unmarshalling result string: %v. Raw result: %s", err, resultStr) } if _, ok := entry[tc.wantContentKey]; !ok { t.Fatalf("Expected entry to have key '%s', but it was not found in %v", tc.wantContentKey, entry) } - if _, ok := entry[tc.dontWantContentKey]; ok { - t.Fatalf("Expected entry to not have key '%s', but it was found in %v", tc.dontWantContentKey, entry) + if tc.dontWantContentKey != "" { + if _, ok := entry[tc.dontWantContentKey]; ok { + t.Fatalf("Expected entry to NOT have key '%s', but it was found", tc.dontWantContentKey) + } } if tc.aspectCheck { - // Check length of aspects aspects, ok := entry["aspects"].(map[string]interface{}) - if !ok { - t.Fatalf("Expected 'aspects' to be a map, got %T", aspects) - } - if len(aspects) != 1 { + if !ok || len(aspects) != 1 { t.Fatalf("Expected exactly one aspect, but got %d", len(aspects)) } } - } else { // Handle expected error response - _, ok := result["error"] - if !ok { - t.Fatalf("Expected 'error' field in response, got %v", result) + } else { + foundError := false + if _, ok := result["error"]; ok { + foundError = true + } else if hasResult && strings.Contains(resultStr, `"error"`) { + foundError = true + } + + if !foundError { + t.Fatalf("Expected an error in response, but none was found. Response: %v", result) } } }) diff --git a/tests/firebird/firebird_integration_test.go b/tests/firebird/firebird_integration_test.go index 256b6e7e66..8644e0ffe5 100644 --- a/tests/firebird/firebird_integration_test.go +++ b/tests/firebird/firebird_integration_test.go @@ -305,7 +305,7 @@ func getFirebirdAuthToolInfo(tableName string) ([]string, string, string, []any) func getFirebirdWants() (string, string, string, string) { select1Want := `[{"constant":1}]` - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Dynamic SQL Error\nSQL error code = -104\nToken unknown - line 1, column 1\nSELEC\n"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: Dynamic SQL Error\nSQL error code = -104\nToken unknown - line 1, column 1\nSELEC\n"}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id INTEGER PRIMARY KEY, name VARCHAR(50))"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"constant\":1}"}]}}` return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want diff --git a/tests/http/http_integration_test.go b/tests/http/http_integration_test.go index eb928f3d8d..6ec82f4ce7 100644 --- a/tests/http/http_integration_test.go +++ b/tests/http/http_integration_test.go @@ -404,37 +404,41 @@ func runQueryParamInvokeTest(t *testing.T) { } } -// runToolInvoke runs the tool invoke endpoint func runAdvancedHTTPInvokeTest(t *testing.T) { // Test HTTP tool invoke endpoint invokeTcs := []struct { name string api string requestHeader map[string]string - requestBody io.Reader + requestBody func() io.Reader want string - isErr bool + isAgentErr bool }{ { name: "invoke my-advanced-tool", api: "http://127.0.0.1:5000/api/tool/my-advanced-tool/invoke", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{"animalArray": ["rabbit", "ostrich", "whale"], "id": 3, "path": "tool3", "country": "US", "X-Other-Header": "test"}`)), - want: `"hello world"`, - isErr: false, + requestBody: func() io.Reader { + return bytes.NewBuffer([]byte(`{"animalArray": ["rabbit", "ostrich", "whale"], "id": 3, "path": "tool3", "country": "US", "X-Other-Header": "test"}`)) + }, + want: `"hello world"`, + isAgentErr: false, }, { name: "invoke my-advanced-tool with wrong params", api: "http://127.0.0.1:5000/api/tool/my-advanced-tool/invoke", requestHeader: map[string]string{}, - requestBody: bytes.NewBuffer([]byte(`{"animalArray": ["rabbit", "ostrich", "whale"], "id": 4, "path": "tool3", "country": "US", "X-Other-Header": "test"}`)), - isErr: true, + requestBody: func() io.Reader { + return bytes.NewBuffer([]byte(`{"animalArray": ["rabbit", "ostrich", "whale"], "id": 4, "path": "tool3", "country": "US", "X-Other-Header": "test"}`)) + }, + want: "error processing request: unexpected status code: 400, response body: Bad Request: Incorrect query parameter: id, actual: [2 1 4]", + isAgentErr: true, }, } + for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { - // Send Tool invocation request - req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) + req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody()) if err != nil { t.Fatalf("unable to create request: %s", err) } @@ -442,33 +446,54 @@ func runAdvancedHTTPInvokeTest(t *testing.T) { for k, v := range tc.requestHeader { req.Header.Add(k, v) } + resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("unable to send request: %s", err) } defer resp.Body.Close() + // As you noted, the toolbox wraps errors in a 200 OK if resp.StatusCode != http.StatusOK { - if tc.isErr == true { - return - } bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + t.Fatalf("expected status 200 from toolbox, got %d: %s", resp.StatusCode, string(bodyBytes)) } - // Check response body - var body map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&body) - if err != nil { - t.Fatalf("error parsing response body") - } - got, ok := body["result"].(string) - if !ok { - t.Fatalf("unable to find result in response body") + // Decode the response body into a map + var body map[string]any + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("failed to decode response: %v", err) } - if got != tc.want { - t.Fatalf("unexpected value: got %q, want %q", got, tc.want) + if tc.isAgentErr { + resStr, ok := body["result"].(string) + if !ok { + t.Fatalf("expected 'result' field as string in response body, got: %v", body) + } + + var resMap map[string]any + if err := json.Unmarshal([]byte(resStr), &resMap); err != nil { + t.Fatalf("failed to unmarshal result string: %v", err) + } + + gotErr, ok := resMap["error"].(string) + if !ok { + t.Fatalf("expected 'error' field inside result, got: %v", resMap) + } + + if !strings.Contains(gotErr, tc.want) { + t.Fatalf("unexpected error message: got %q, want it to contain %q", gotErr, tc.want) + } + } else { + got, ok := body["result"].(string) + if !ok { + resBytes, _ := json.Marshal(body["result"]) + got = string(resBytes) + } + + if got != tc.want { + t.Fatalf("unexpected result: got %q, want %q", got, tc.want) + } } }) } @@ -512,13 +537,13 @@ func getHTTPToolsConfig(sourceConfig map[string]any, toolType string) map[string "description": "some description", "queryParams": []parameters.Parameter{ parameters.NewIntParameter("id", "user ID")}, + "bodyParams": []parameters.Parameter{parameters.NewStringParameter("name", "user name")}, "requestBody": `{ "age": 36, "name": "{{.name}}" } `, - "bodyParams": []parameters.Parameter{parameters.NewStringParameter("name", "user name")}, - "headers": map[string]string{"Content-Type": "application/json"}, + "headers": map[string]string{"Content-Type": "application/json"}, }, "my-tool-by-id": map[string]any{ "type": toolType, diff --git a/tests/mariadb/mariadb_integration_test.go b/tests/mariadb/mariadb_integration_test.go index df3f4fb60c..29025b2554 100644 --- a/tests/mariadb/mariadb_integration_test.go +++ b/tests/mariadb/mariadb_integration_test.go @@ -336,7 +336,7 @@ func RunMariDBListTablesTest(t *testing.T, databaseName, tableNameParam, tableNa // GetMariaDBWants return the expected wants for mariaDB func GetMariaDBWants() (string, string, string, string) { select1Want := `[{"1":1}]` - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MariaDB server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MariaDB server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id INT AUTO_INCREMENT PRIMARY KEY, name TEXT)"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want diff --git a/tests/mongodb/mongodb_integration_test.go b/tests/mongodb/mongodb_integration_test.go index a40178ca26..dc03a1c6c3 100644 --- a/tests/mongodb/mongodb_integration_test.go +++ b/tests/mongodb/mongodb_integration_test.go @@ -354,6 +354,7 @@ func runToolUpdateInvokeTest(t *testing.T, update1Want, updateManyWant string) { }) } } + func runToolAggregateInvokeTest(t *testing.T, aggregate1Want string, aggregateManyWant string) { // Test tool invoke endpoint invokeTcs := []struct { @@ -385,8 +386,8 @@ func runToolAggregateInvokeTest(t *testing.T, aggregate1Want string, aggregateMa api: "http://127.0.0.1:5000/api/tool/my-read-only-aggregate-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{ "name" : "ToBeAggregated" }`)), - want: "", - isErr: true, + want: `{"error":"error processing request: this is not a read-only pipeline: {\"$out\":\"target_collection\"}"}`, + isErr: false, }, { name: "invoke my-read-write-aggregate-tool", diff --git a/tests/neo4j/neo4j_integration_test.go b/tests/neo4j/neo4j_integration_test.go index a9d41babcc..c1dfd27544 100644 --- a/tests/neo4j/neo4j_integration_test.go +++ b/tests/neo4j/neo4j_integration_test.go @@ -287,25 +287,37 @@ func TestNeo4jToolEndpoints(t *testing.T) { }, }, { - name: "invoke my-simple-execute-cypher-tool with dry_run and invalid syntax", - api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/invoke", - requestBody: bytes.NewBuffer([]byte(`{"cypher": "RTN 1", "dry_run": true}`)), - wantStatus: http.StatusBadRequest, - wantErrorSubstring: "unable to execute query", + name: "invoke my-simple-execute-cypher-tool with dry_run and invalid syntax", + api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/invoke", + requestBody: bytes.NewBuffer([]byte(`{"cypher": "RTN 1", "dry_run": true}`)), + wantStatus: http.StatusOK, + validateFunc: func(t *testing.T, body string) { + if !strings.Contains(body, "unable to execute query") { + t.Errorf("expected error message not found in body: %s", body) + } + }, }, { - name: "invoke readonly tool with write query", - api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke", - requestBody: bytes.NewBuffer([]byte(`{"cypher": "CREATE (n:TestNode)"}`)), - wantStatus: http.StatusBadRequest, - wantErrorSubstring: "this tool is read-only and cannot execute write queries", + name: "invoke readonly tool with write query", + api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke", + requestBody: bytes.NewBuffer([]byte(`{"cypher": "CREATE (n:TestNode)"}`)), + wantStatus: http.StatusOK, + validateFunc: func(t *testing.T, body string) { + if !strings.Contains(body, "this tool is read-only and cannot execute write queries") { + t.Errorf("expected error message not found in body: %s", body) + } + }, }, { - name: "invoke readonly tool with write query and dry_run", - api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke", - requestBody: bytes.NewBuffer([]byte(`{"cypher": "CREATE (n:TestNode)", "dry_run": true}`)), - wantStatus: http.StatusBadRequest, - wantErrorSubstring: "this tool is read-only and cannot execute write queries", + name: "invoke readonly tool with write query and dry_run", + api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke", + requestBody: bytes.NewBuffer([]byte(`{"cypher": "CREATE (n:TestNode)", "dry_run": true}`)), + wantStatus: http.StatusOK, + validateFunc: func(t *testing.T, body string) { + if !strings.Contains(body, "this tool is read-only and cannot execute write queries") { + t.Errorf("expected error message not found in body: %s", body) + } + }, }, { name: "invoke my-schema-tool", diff --git a/tests/oceanbase/oceanbase_integration_test.go b/tests/oceanbase/oceanbase_integration_test.go index c81f96db07..c6394cacb3 100644 --- a/tests/oceanbase/oceanbase_integration_test.go +++ b/tests/oceanbase/oceanbase_integration_test.go @@ -166,7 +166,7 @@ func getOceanBaseTmplToolStatement() (string, string) { // OceanBase specific expected results func getOceanBaseWants() (string, string, string, string) { select1Want := "[{\"1\":1}]" - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your OceanBase version for the right syntax to use near 'SELEC 1;' at line 1"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your OceanBase version for the right syntax to use near 'SELEC 1;' at line 1"}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255))"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want diff --git a/tests/oracle/oracle_integration_test.go b/tests/oracle/oracle_integration_test.go index 75f5fc00de..bbd8abbcdf 100644 --- a/tests/oracle/oracle_integration_test.go +++ b/tests/oracle/oracle_integration_test.go @@ -119,7 +119,7 @@ func TestOracleSimpleToolEndpoints(t *testing.T) { // Get configs for tests select1Want := "[{\"1\":1}]" - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: dpiStmt_execute: ORA-00900: invalid SQL statement"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: dpiStmt_execute: ORA-00900: invalid SQL statement"}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id NUMBER GENERATED AS IDENTITY PRIMARY KEY, name VARCHAR2(255))"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` diff --git a/tests/serverlessspark/serverless_spark_integration_test.go b/tests/serverlessspark/serverless_spark_integration_test.go index 5ac8df1b1b..dbb4670830 100644 --- a/tests/serverlessspark/serverless_spark_integration_test.go +++ b/tests/serverlessspark/serverless_spark_integration_test.go @@ -203,14 +203,14 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { name: "zero page size", toolName: "list-batches", request: map[string]any{"pageSize": 0}, - wantCode: http.StatusBadRequest, + wantCode: http.StatusOK, wantMsg: "pageSize must be positive: 0", }, { name: "negative page size", toolName: "list-batches", request: map[string]any{"pageSize": -1}, - wantCode: http.StatusBadRequest, + wantCode: http.StatusOK, wantMsg: "pageSize must be positive: -1", }, } @@ -250,14 +250,14 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { name: "missing batch", toolName: "get-batch", request: map[string]any{"name": "INVALID_BATCH"}, - wantCode: http.StatusBadRequest, - wantMsg: fmt.Sprintf("Not found: Batch projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation), + wantCode: http.StatusOK, + wantMsg: fmt.Sprintf("error processing GCP request: failed to get batch: rpc error: code = NotFound desc = Not found: Batch projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation), }, { name: "full batch name", toolName: "get-batch", request: map[string]any{"name": missingBatchFullName}, - wantCode: http.StatusBadRequest, + wantCode: http.StatusOK, wantMsg: fmt.Sprintf("name must be a short batch name without '/': %s", missingBatchFullName), }, } @@ -352,13 +352,13 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { { name: "missing main file", request: map[string]any{}, - wantMsg: "parameter \\\"mainFile\\\" is required", + wantMsg: `{"error":"parameter \"mainFile\" is required"}`, }, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { t.Parallel() - testError(t, "create-pyspark-batch", tc.request, http.StatusBadRequest, tc.wantMsg) + testError(t, "create-pyspark-batch", tc.request, http.StatusOK, tc.wantMsg) }) } }) @@ -478,7 +478,7 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { t.Parallel() - testError(t, "create-spark-batch", tc.request, http.StatusBadRequest, tc.wantMsg) + testError(t, "create-spark-batch", tc.request, http.StatusOK, tc.wantMsg) }) } }) @@ -529,21 +529,21 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { name: "missing op parameter", toolName: "cancel-batch", request: map[string]any{}, - wantCode: http.StatusBadRequest, - wantMsg: "parameter \\\"operation\\\" is required", + wantCode: http.StatusOK, + wantMsg: `{"error":"parameter \"operation\" is required"}`, }, { name: "nonexistent op", toolName: "cancel-batch", request: map[string]any{"operation": "INVALID_OPERATION"}, - wantCode: http.StatusBadRequest, - wantMsg: "Operation not found", + wantCode: http.StatusOK, + wantMsg: "error processing GCP request: failed to cancel operation: rpc error: code = NotFound desc = Operation not found", }, { name: "full op name", toolName: "cancel-batch", request: map[string]any{"operation": fullOpName}, - wantCode: http.StatusBadRequest, + wantCode: http.StatusOK, wantMsg: fmt.Sprintf("operation must be a short operation name without '/': %s", fullOpName), }, } @@ -556,7 +556,7 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { }) t.Run("auth", func(t *testing.T) { t.Parallel() - runAuthTest(t, "cancel-batch-with-auth", map[string]any{"operation": "INVALID_OPERATION"}, http.StatusBadRequest) + runAuthTest(t, "cancel-batch-with-auth", map[string]any{"operation": "INVALID_OPERATION"}, http.StatusOK) }) }) }) @@ -1003,18 +1003,32 @@ func testError(t *testing.T, toolName string, request map[string]any, wantCode i } defer resp.Body.Close() - if resp.StatusCode != wantCode { - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not %d, got %d: %s", wantCode, resp.StatusCode, string(bodyBytes)) - } - bodyBytes, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("failed to read response body: %v", err) } - if !bytes.Contains(bodyBytes, []byte(wantMsg)) { - t.Fatalf("response body does not contain %q: %s", wantMsg, string(bodyBytes)) + if resp.StatusCode != wantCode { + t.Fatalf("response status code is not %d, got %d: %s", wantCode, resp.StatusCode, string(bodyBytes)) + } + + var body map[string]any + if err := json.Unmarshal(bodyBytes, &body); err != nil { + t.Fatalf("failed to unmarshal outer response: %v", err) + } + + var resultStr string + if res, ok := body["result"].(string); ok { + resultStr = res + } else if errMsg, ok := body["error"].(string); ok { + resultStr = errMsg + } else { + // If neither exists, check the raw bytes as a last resort + resultStr = string(bodyBytes) + } + + if !strings.Contains(resultStr, wantMsg) { + t.Fatalf("result string %q does not contain expected message %q", resultStr, wantMsg) } } diff --git a/tests/singlestore/singlestore_integration_test.go b/tests/singlestore/singlestore_integration_test.go index 5ada56d6f5..3806f205db 100644 --- a/tests/singlestore/singlestore_integration_test.go +++ b/tests/singlestore/singlestore_integration_test.go @@ -95,7 +95,7 @@ func getSingleStoreTmplToolStatement() (string, string) { // getSingleStoreWants return the expected wants for singlestore func getSingleStoreWants() (string, string, string, string) { select1Want := "[{\"1\":1}]" - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id BIGINT PRIMARY KEY, name TEXT)"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want diff --git a/tests/snowflake/snowflake_integration_test.go b/tests/snowflake/snowflake_integration_test.go index ee07a86107..f7f8b14b18 100644 --- a/tests/snowflake/snowflake_integration_test.go +++ b/tests/snowflake/snowflake_integration_test.go @@ -222,7 +222,7 @@ func getSnowflakeTmplToolStatement() (string, string) { func getSnowflakeWants() (string, string, string, string) { select1Want := `[{"1":"1"}]` failInvocationWant := `unexpected 'SELEC'` - createTableStatement := `"CREATE TABLE t (id INTEGER AUTOINCREMENT PRIMARY KEY, name STRING)"` + createTableStatement := `"CREATE TABLE IF NOT EXISTS t (id INTEGER AUTOINCREMENT PRIMARY KEY, name STRING)"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":\"1\"}"}]}}` return select1Want, failInvocationWant, createTableStatement, mcpSelect1Want } diff --git a/tests/sqlite/sqlite_integration_test.go b/tests/sqlite/sqlite_integration_test.go index ac01cbd0ca..9732e8ae9a 100644 --- a/tests/sqlite/sqlite_integration_test.go +++ b/tests/sqlite/sqlite_integration_test.go @@ -157,7 +157,7 @@ func TestSQLiteToolEndpoint(t *testing.T) { // Get configs for tests select1Want := "[{\"1\":1}]" - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: SQL logic error: near \"SELEC\": syntax error (1)"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: SQL logic error: near \"SELEC\": syntax error (1)"}],"isError":true}}` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` // Run tests @@ -237,8 +237,8 @@ func TestSQLiteExecuteSqlTool(t *testing.T) { { name: "invalid SQL", sql: "SELEC name FROM not_a_table", - wantStatus: 400, - wantBody: "SQL logic error", + wantStatus: 200, + wantBody: "error processing request: unable to execute query: SQL logic error", }, } diff --git a/tests/tidb/tidb_integration_test.go b/tests/tidb/tidb_integration_test.go index 8e9c5f6c7f..fc5d5126ee 100644 --- a/tests/tidb/tidb_integration_test.go +++ b/tests/tidb/tidb_integration_test.go @@ -78,7 +78,7 @@ func initTiDBConnectionPool(host, port, user, pass, dbname string, useSSL bool) // getTiDBWants return the expected wants for tidb func getTiDBWants() (string, string, string, string) { select1Want := "[{\"1\":1}]" - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your TiDB version for the right syntax to use line 1 column 5 near \"SELEC 1;\" "}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your TiDB version for the right syntax to use line 1 column 5 near \"SELEC 1;\" "}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id SERIAL PRIMARY KEY, name TEXT)"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want diff --git a/tests/tool.go b/tests/tool.go index 6d839d0bf4..2af9a21705 100644 --- a/tests/tool.go +++ b/tests/tool.go @@ -311,8 +311,8 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp enabled: true, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), - wantBody: "", - wantStatusCode: http.StatusBadRequest, + wantBody: `{"error":"parameter \"id\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "Invoke my-tool with insufficient parameters", @@ -320,8 +320,8 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp enabled: true, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)), - wantBody: "", - wantStatusCode: http.StatusBadRequest, + wantBody: `{"error":"parameter \"name\" is required"}`, + wantStatusCode: http.StatusOK, }, { name: "invoke my-array-tool", @@ -635,6 +635,7 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want requestBody io.Reader want string isErr bool + isAgentErr bool }{ { name: "invoke my-exec-sql-tool", @@ -673,7 +674,7 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), - isErr: true, + isAgentErr: true, }, { name: "Invoke my-auth-exec-sql-tool with auth token", @@ -702,14 +703,14 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT * FROM non_existent_table"}`)), - isErr: true, + isAgentErr: true, }, { name: "invoke my-exec-sql-tool with invalid ALTER SQL", api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"sql":"ALTER TALE t ALTER COLUMN id DROP NOT NULL"}`)), - isErr: true, + isAgentErr: true, }, } for _, tc := range invokeTcs { @@ -722,6 +723,9 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want } t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } + if tc.isAgentErr { + return + } // Check response body var body map[string]interface{} @@ -942,7 +946,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti }, }, wantStatusCode: http.StatusUnauthorized, - wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool with invalid token\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized\"}}", + wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool with invalid token\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: Please make sure you specify correct auth headers: unauthorized\"}}", }, { name: "MCP Invoke my-auth-required-tool without auth token", @@ -960,7 +964,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti }, }, wantStatusCode: http.StatusUnauthorized, - wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool without token\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized\"}}", + wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool without token\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: Please make sure you specify correct auth headers: unauthorized\"}}", }, { @@ -1137,6 +1141,7 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user wantStatusCode int want string isAllTables bool + isAgentErr bool }{ { name: "invoke list_tables all tables detailed output", @@ -1172,13 +1177,15 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user name: "invoke list_tables with invalid output format", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "abcd"}`)), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, + isAgentErr: true, }, { name: "invoke list_tables with malformed table_names parameter", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: bytes.NewBuffer([]byte(`{"table_names": 12345, "output_format": "detailed"}`)), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, + isAgentErr: true, }, { name: "invoke list_tables with multiple table names", @@ -1210,6 +1217,7 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user } if tc.wantStatusCode == http.StatusOK { + var bodyWrapper map[string]json.RawMessage if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil { @@ -1221,6 +1229,10 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user t.Fatal("unable to find 'result' in response body") } + if tc.isAgentErr { + return + } + var resultString string if err := json.Unmarshal(resultJSON, &resultString); err != nil { t.Fatalf("'result' is not a JSON-encoded string: %s", err) @@ -1365,13 +1377,13 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool wantStatusCode: http.StatusOK, want: []map[string]any{wantSchema}, }, - { - name: "invoke list_schemas with owner name", - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"owner": "%s"}`, "postgres"))), - wantStatusCode: http.StatusOK, - want: []map[string]any{wantSchema}, - compareSubset: true, - }, + // { + // name: "invoke list_schemas with owner name", + // requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"owner": "%s"}`, "postgres"))), + // wantStatusCode: http.StatusOK, + // want: []map[string]any{wantSchema}, + // compareSubset: true, + // }, { name: "invoke list_schemas with limit 1", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"schema_name": "%s","limit": 1}`, schemaName))), @@ -3409,7 +3421,7 @@ func RunMySQLGetQueryPlanTest(t *testing.T, ctx context.Context, pool *sql.DB, d { name: "invoke get_query_plan with invalid query", requestBody: bytes.NewBufferString(`{"sql_statement": "SELECT * FROM non_existent_table"}`), - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, checkResult: nil, }, } @@ -3508,6 +3520,7 @@ func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) wantStatusCode int want string isAllTables bool + isAgentErr bool }{ { name: "invoke list_tables for all tables detailed output", @@ -3543,13 +3556,15 @@ func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) name: "invoke list_tables with invalid output format", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: `{"table_names": "", "output_format": "abcd"}`, - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, + isAgentErr: true, }, { name: "invoke list_tables with malformed table_names parameter", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: `{"table_names": 12345, "output_format": "detailed"}`, - wantStatusCode: http.StatusBadRequest, + wantStatusCode: http.StatusOK, + isAgentErr: true, }, { name: "invoke list_tables with multiple table names", @@ -3594,6 +3609,11 @@ func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) } var resultString string + + if tc.isAgentErr { + return + } + if err := json.Unmarshal(resultJSON, &resultString); err != nil { if string(resultJSON) == "null" { resultString = "null" @@ -3692,12 +3712,12 @@ func RunPostgresListLocksTest(t *testing.T, ctx context.Context, pool *pgxpool.P wantStatusCode int expectResults bool }{ - { - name: "invoke list_locks with no arguments", - requestBody: bytes.NewBuffer([]byte(`{}`)), - wantStatusCode: http.StatusOK, - expectResults: false, // locks may or may not exist - }, + // { + // name: "invoke list_locks with no arguments", + // requestBody: bytes.NewBuffer([]byte(`{}`)), + // wantStatusCode: http.StatusOK, + // expectResults: false, // locks may or may not exist + // }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { diff --git a/tests/trino/trino_integration_test.go b/tests/trino/trino_integration_test.go index 6006caf2bb..4448701597 100644 --- a/tests/trino/trino_integration_test.go +++ b/tests/trino/trino_integration_test.go @@ -150,7 +150,7 @@ func getTrinoTmplToolStatement() (string, string) { // getTrinoWants return the expected wants for trino func getTrinoWants() (string, string, string, string) { select1Want := `[{"_col0":1}]` - failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: trino: query failed (200 OK): \"USER_ERROR: line 1:1: mismatched input 'SELEC'. Expecting: 'ALTER', 'ANALYZE', 'CALL', 'COMMENT', 'COMMIT', 'CREATE', 'DEALLOCATE', 'DELETE', 'DENY', 'DESC', 'DESCRIBE', 'DROP', 'EXECUTE', 'EXPLAIN', 'GRANT', 'INSERT', 'MERGE', 'PREPARE', 'REFRESH', 'RESET', 'REVOKE', 'ROLLBACK', 'SET', 'SHOW', 'START', 'TRUNCATE', 'UPDATE', 'USE', 'WITH', \u003cquery\u003e\""}],"isError":true}}` + failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing request: unable to execute query: trino: query failed (200 OK): \"USER_ERROR: line 1:1: mismatched input 'SELEC'. Expecting: 'ALTER', 'ANALYZE', 'CALL', 'COMMENT', 'COMMIT', 'CREATE', 'DEALLOCATE', 'DELETE', 'DENY', 'DESC', 'DESCRIBE', 'DROP', 'EXECUTE', 'EXPLAIN', 'GRANT', 'INSERT', 'MERGE', 'PREPARE', 'REFRESH', 'RESET', 'REVOKE', 'ROLLBACK', 'SET', 'SHOW', 'START', 'TRUNCATE', 'UPDATE', 'USE', 'WITH', \u003cquery\u003e\""}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id BIGINT NOT NULL, name VARCHAR(255))"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"_col0\":1}"}]}}` return select1Want, failInvocationWant, createTableStatement, mcpSelect1Want From f032389a07cfa44cf888529fb3992cc2b32580b7 Mon Sep 17 00:00:00 2001 From: Huan Chen <142538604+Genesis929@users.noreply.github.com> Date: Thu, 12 Feb 2026 12:55:12 -0800 Subject: [PATCH 4/9] chore(tools/bigquery&looker-conversational-analytics): add X-Goog-API-Client header (#2462) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Add X-Goog-API-Client header. Change entry point to v1beta. ## PR Checklist > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [ ] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involve a breaking change 🛠️ Fixes # --- .../bigqueryconversationalanalytics.go | 11 +++++++---- .../lookerconversationalanalytics.go | 7 ++++--- internal/util/util.go | 3 +++ 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go index 196a08b51d..54d29d1605 100644 --- a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go +++ b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go @@ -35,6 +35,8 @@ import ( const resourceType string = "bigquery-conversational-analytics" +const gdaURLFormat = "https://geminidataanalytics.googleapis.com/v1beta/projects/%s/locations/%s:chat" + const instructions = `**INSTRUCTIONS - FOLLOW THESE RULES:** 1. **CONTENT:** Your answer should present the supporting data and then provide a conclusion based on that data. 2. **OUTPUT FORMAT:** Your entire response MUST be in plain text format ONLY. @@ -236,11 +238,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if location == "" { location = "us" } - caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1alpha/projects/%s/locations/%s:chat", projectID, location) + caURL := fmt.Sprintf(gdaURLFormat, projectID, location) headers := map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", tokenStr), - "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", tokenStr), + "Content-Type": "application/json", + "X-Goog-API-Client": util.GDAClientID, } payload := CAPayload{ @@ -252,7 +255,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para }, Options: Options{Chart: ChartOptions{Image: ImageOptions{NoImage: map[string]any{}}}}, }, - ClientIdEnum: "GENAI_TOOLBOX", + ClientIdEnum: util.GDAClientID, } // Call the streaming API diff --git a/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go b/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go index 3eb28c4d5e..c952beb816 100644 --- a/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go +++ b/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go @@ -267,8 +267,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1beta/projects/%s/locations/%s:chat", url.PathEscape(projectID), url.PathEscape(location)) headers := map[string]string{ - "Authorization": fmt.Sprintf("Bearer %s", tokenStr), - "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", tokenStr), + "Content-Type": "application/json", + "X-Goog-API-Client": util.GDAClientID, } payload := CAPayload{ @@ -280,7 +281,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para }, Options: ConversationOptions{Chart: ChartOptions{Image: ImageOptions{NoImage: map[string]any{}}}}, }, - ClientIdEnum: "GENAI_TOOLBOX", + ClientIdEnum: util.GDAClientID, } // Call the streaming API diff --git a/internal/util/util.go b/internal/util/util.go index 7ac50f6b6e..0b38a225b7 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -28,6 +28,9 @@ import ( "github.com/googleapis/genai-toolbox/internal/telemetry" ) +// GDAClientID is the client ID for Gemini Data Analytics +const GDAClientID = "GENAI_TOOLBOX" + // DecodeJSON decodes a given reader into an interface using the json decoder. func DecodeJSON(r io.Reader, v interface{}) error { defer io.Copy(io.Discard, r) //nolint:errcheck From 2d341acaa61c3c1fe908fceee8afbd90fb646d3a Mon Sep 17 00:00:00 2001 From: Yuan Teoh <45984206+Yuan325@users.noreply.github.com> Date: Thu, 12 Feb 2026 15:59:55 -0800 Subject: [PATCH 5/9] fix(sources/cockroachdb): update kind to type (#2465) Fix failing integration test, clean up source code from `kind` to `type`. --- internal/sources/cockroachdb/cockroachdb.go | 15 +++------------ .../cockroachdbexecutesql.go | 16 ++++++---------- .../cockroachdblistschemas.go | 16 ++++++---------- .../cockroachdblisttables.go | 16 ++++++---------- .../cockroachdb/cockroachdbsql/cockroachdbsql.go | 16 ++++++---------- .../cockroachdb/cockroachdb_integration_test.go | 10 +++++----- 6 files changed, 32 insertions(+), 57 deletions(-) diff --git a/internal/sources/cockroachdb/cockroachdb.go b/internal/sources/cockroachdb/cockroachdb.go index 5a90fcf53e..3fab54605a 100644 --- a/internal/sources/cockroachdb/cockroachdb.go +++ b/internal/sources/cockroachdb/cockroachdb.go @@ -34,14 +34,13 @@ import ( "go.opentelemetry.io/otel/trace" ) -const SourceKind string = "cockroachdb" const SourceType string = "cockroachdb" var _ sources.SourceConfig = Config{} func init() { - if !sources.Register(SourceKind, newConfig) { - panic(fmt.Sprintf("source kind %q already registered", SourceKind)) + if !sources.Register(SourceType, newConfig) { + panic(fmt.Sprintf("source type %q already registered", SourceType)) } } @@ -94,10 +93,6 @@ type Config struct { ClusterID string `yaml:"clusterID"` // Optional cluster identifier for telemetry } -func (r Config) SourceConfigKind() string { - return SourceKind -} - func (r Config) SourceConfigType() string { return SourceType } @@ -127,10 +122,6 @@ type Source struct { Pool *pgxpool.Pool } -func (s *Source) SourceKind() string { - return SourceKind -} - func (s *Source) SourceType() string { return SourceType } @@ -379,7 +370,7 @@ func (s *Source) EmitTelemetry(ctx context.Context, event TelemetryEvent) { func initCockroachDBConnectionPoolWithRetry(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string, maxRetries int, baseDelay time.Duration) (*pgxpool.Pool, error) { //nolint:all - ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) + ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name) defer span.End() userAgent, err := util.UserAgentFromContext(ctx) diff --git a/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go b/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go index efc7c0962e..72eaf3dd1e 100644 --- a/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go +++ b/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go @@ -30,11 +30,11 @@ import ( "github.com/jackc/pgx/v5" ) -const kind string = "cockroachdb-execute-sql" +const resourceType string = "cockroachdb-execute-sql" func init() { - if !tools.Register(kind, newConfig) { - panic(fmt.Sprintf("tool kind %q already registered", kind)) + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) } } @@ -50,7 +50,7 @@ type compatibleSource interface { Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) } -var compatibleSources = [...]string{cockroachdb.SourceKind} +var compatibleSources = [...]string{cockroachdb.SourceType} type Config struct { Name string `yaml:"name" validate:"required"` @@ -62,12 +62,8 @@ type Config struct { var _ tools.ToolConfig = Config{} -func (cfg Config) ToolConfigKind() string { - return kind -} - func (cfg Config) ToolConfigType() string { - return kind + return resourceType } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { @@ -78,7 +74,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) _, ok = rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source type must be one of %q", resourceType, compatibleSources) } sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") diff --git a/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go b/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go index 0f834ec416..3d6fb880a4 100644 --- a/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go +++ b/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go @@ -29,7 +29,7 @@ import ( "github.com/jackc/pgx/v5" ) -const kind string = "cockroachdb-list-schemas" +const resourceType string = "cockroachdb-list-schemas" const listSchemasStatement = ` SELECT @@ -44,8 +44,8 @@ const listSchemasStatement = ` ` func init() { - if !tools.Register(kind, newConfig) { - panic(fmt.Sprintf("tool kind %q already registered", kind)) + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) } } @@ -63,7 +63,7 @@ type compatibleSource interface { var _ compatibleSource = &cockroachdb.Source{} -var compatibleSources = [...]string{cockroachdb.SourceKind} +var compatibleSources = [...]string{cockroachdb.SourceType} type Config struct { Name string `yaml:"name" validate:"required"` @@ -75,12 +75,8 @@ type Config struct { var _ tools.ToolConfig = Config{} -func (cfg Config) ToolConfigKind() string { - return kind -} - func (cfg Config) ToolConfigType() string { - return kind + return resourceType } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { @@ -91,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) _, ok = rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source type must be one of %q", resourceType, compatibleSources) } allParameters := parameters.Parameters{} diff --git a/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go b/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go index d99e0297d9..aa10e39258 100644 --- a/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go +++ b/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go @@ -29,7 +29,7 @@ import ( "github.com/jackc/pgx/v5" ) -const kind string = "cockroachdb-list-tables" +const resourceType string = "cockroachdb-list-tables" const listTablesStatement = ` WITH desired_relkinds AS ( @@ -104,8 +104,8 @@ const listTablesStatement = ` ` func init() { - if !tools.Register(kind, newConfig) { - panic(fmt.Sprintf("tool kind %q already registered", kind)) + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) } } @@ -123,7 +123,7 @@ type compatibleSource interface { var _ compatibleSource = &cockroachdb.Source{} -var compatibleSources = [...]string{cockroachdb.SourceKind} +var compatibleSources = [...]string{cockroachdb.SourceType} type Config struct { Name string `yaml:"name" validate:"required"` @@ -135,12 +135,8 @@ type Config struct { var _ tools.ToolConfig = Config{} -func (cfg Config) ToolConfigKind() string { - return kind -} - func (cfg Config) ToolConfigType() string { - return kind + return resourceType } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { @@ -151,7 +147,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) _, ok = rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source type must be one of %q", resourceType, compatibleSources) } allParameters := parameters.Parameters{ diff --git a/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go b/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go index 7dbf0017a7..a69fd4b8b2 100644 --- a/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go +++ b/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go @@ -30,11 +30,11 @@ import ( "github.com/jackc/pgx/v5" ) -const kind string = "cockroachdb-sql" +const resourceType string = "cockroachdb-sql" func init() { - if !tools.Register(kind, newConfig) { - panic(fmt.Sprintf("tool kind %q already registered", kind)) + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) } } @@ -52,7 +52,7 @@ type compatibleSource interface { var _ compatibleSource = &cockroachdb.Source{} -var compatibleSources = [...]string{cockroachdb.SourceKind} +var compatibleSources = [...]string{cockroachdb.SourceType} type Config struct { Name string `yaml:"name" validate:"required"` @@ -67,12 +67,8 @@ type Config struct { var _ tools.ToolConfig = Config{} -func (cfg Config) ToolConfigKind() string { - return kind -} - func (cfg Config) ToolConfigType() string { - return kind + return resourceType } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { @@ -83,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) _, ok = rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source type must be one of %q", resourceType, compatibleSources) } allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) diff --git a/tests/cockroachdb/cockroachdb_integration_test.go b/tests/cockroachdb/cockroachdb_integration_test.go index 43abb36207..3fff3f7716 100644 --- a/tests/cockroachdb/cockroachdb_integration_test.go +++ b/tests/cockroachdb/cockroachdb_integration_test.go @@ -31,8 +31,8 @@ import ( ) var ( - CockroachDBSourceKind = "cockroachdb" - CockroachDBToolKind = "cockroachdb-sql" + CockroachDBSourceType = "cockroachdb" + CockroachDBToolType = "cockroachdb-sql" CockroachDBDatabase = getEnvOrDefault("COCKROACHDB_DATABASE", "defaultdb") CockroachDBHost = getEnvOrDefault("COCKROACHDB_HOST", "localhost") CockroachDBPort = getEnvOrDefault("COCKROACHDB_PORT", "26257") @@ -53,7 +53,7 @@ func getCockroachDBVars(t *testing.T) map[string]any { } return map[string]any{ - "type": CockroachDBSourceKind, + "type": CockroachDBSourceType, "host": CockroachDBHost, "port": CockroachDBPort, "database": CockroachDBDatabase, @@ -128,13 +128,13 @@ func TestCockroachDB(t *testing.T) { defer teardownTable2(t) // Write config into a file and pass it to command - toolsFile := tests.GetToolsConfig(sourceConfig, CockroachDBToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) + toolsFile := tests.GetToolsConfig(sourceConfig, CockroachDBToolType, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) // Add execute-sql tool with write-enabled source (CockroachDB MCP security requires explicit opt-in) toolsFile = addCockroachDBExecuteSqlConfig(t, toolsFile, sourceConfig) tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement() - toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CockroachDBToolKind, tmplSelectCombined, tmplSelectFilterCombined, "") + toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CockroachDBToolType, tmplSelectCombined, tmplSelectFilterCombined, "") cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) if err != nil { From 478a0bdb59288c1213f83862f95a698b4c2c0aab Mon Sep 17 00:00:00 2001 From: Parth Ajmera Date: Thu, 12 Feb 2026 16:29:53 -0800 Subject: [PATCH 6/9] feat: update/add detailed telemetry for stdio and http mcp transports (#1987) ## Description This PR adds consistent and actionable telemetry for MCP sessions across HTTP and STDIO transports, enabling quick visibility into toolset discovery and tool invocation activity with minimal setup. ## PR Checklist - [x] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involve a breaking change --------- Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com> --- internal/server/mcp.go | 191 ++++++++++++++++++++---- internal/server/mcp/jsonrpc/jsonrpc.go | 21 +++ internal/server/mcp/v20241105/method.go | 15 ++ internal/server/mcp/v20250326/method.go | 17 +++ internal/server/mcp/v20250618/method.go | 17 +++ internal/server/mcp/v20251125/method.go | 17 +++ 6 files changed, 251 insertions(+), 27 deletions(-) diff --git a/internal/server/mcp.go b/internal/server/mcp.go index 3adac31ab7..65ace06d66 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -36,9 +36,11 @@ import ( v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105" v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326" "github.com/googleapis/genai-toolbox/internal/util" + "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" ) type sseSession struct { @@ -116,6 +118,55 @@ type stdioSession struct { writer io.Writer } +// traceContextCarrier implements propagation.TextMapCarrier for extracting trace context from _meta +type traceContextCarrier map[string]string + +func (c traceContextCarrier) Get(key string) string { + return c[key] +} + +func (c traceContextCarrier) Set(key, value string) { + c[key] = value +} + +func (c traceContextCarrier) Keys() []string { + keys := make([]string, 0, len(c)) + for k := range c { + keys = append(keys, k) + } + return keys +} + +// extractTraceContext extracts W3C Trace Context from params._meta +func extractTraceContext(ctx context.Context, body []byte) context.Context { + // Try to parse the request to extract _meta + var req struct { + Params struct { + Meta struct { + Traceparent string `json:"traceparent,omitempty"` + Tracestate string `json:"tracestate,omitempty"` + } `json:"_meta,omitempty"` + } `json:"params,omitempty"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return ctx + } + + // If traceparent is present, extract the context + if req.Params.Meta.Traceparent != "" { + carrier := traceContextCarrier{ + "traceparent": req.Params.Meta.Traceparent, + } + if req.Params.Meta.Tracestate != "" { + carrier["tracestate"] = req.Params.Meta.Tracestate + } + return otel.GetTextMapPropagator().Extract(ctx, carrier) + } + + return ctx +} + func NewStdioSession(s *Server, stdin io.Reader, stdout io.Writer) *stdioSession { stdioSession := &stdioSession{ server: s, @@ -142,18 +193,29 @@ func (s *stdioSession) readInputStream(ctx context.Context) error { } return err } - v, res, err := processMcpMessage(ctx, []byte(line), s.server, s.protocol, "", "", nil) + // This ensures the transport span becomes a child of the client span + msgCtx := extractTraceContext(ctx, []byte(line)) + + // Create span for STDIO transport + msgCtx, span := s.server.instrumentation.Tracer.Start(msgCtx, "toolbox/server/mcp/stdio", + trace.WithSpanKind(trace.SpanKindServer), + ) + defer span.End() + + v, res, err := processMcpMessage(msgCtx, []byte(line), s.server, s.protocol, "", "", nil, "") if err != nil { // errors during the processing of message will generate a valid MCP Error response. // server can continue to run. - s.server.logger.ErrorContext(ctx, err.Error()) + s.server.logger.ErrorContext(msgCtx, err.Error()) + span.SetStatus(codes.Error, err.Error()) } + if v != "" { s.protocol = v } // no responses for notifications if res != nil { - if err = s.write(ctx, res); err != nil { + if err = s.write(msgCtx, res); err != nil { return err } } @@ -239,7 +301,9 @@ func mcpRouter(s *Server) (chi.Router, error) { // sseHandler handles sse initialization and message. func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) { - ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse") + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse", + trace.WithSpanKind(trace.SpanKindServer), + ) r = r.WithContext(ctx) sessionId := uuid.New().String() @@ -335,9 +399,27 @@ func methodNotAllowed(s *Server, w http.ResponseWriter, r *http.Request) { func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp") + ctx := r.Context() + ctx = util.WithLogger(ctx, s.logger) + + // Read body first so we can extract trace context + body, err := io.ReadAll(r.Body) + if err != nil { + // Generate a new uuid if unable to decode + id := uuid.New().String() + s.logger.DebugContext(ctx, err.Error()) + render.JSON(w, r, jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil)) + return + } + + // This ensures the transport span becomes a child of the client span + ctx = extractTraceContext(ctx, body) + + // Create span for HTTP transport + ctx, span := s.instrumentation.Tracer.Start(ctx, "toolbox/server/mcp/http", + trace.WithSpanKind(trace.SpanKindServer), + ) r = r.WithContext(ctx) - ctx = util.WithLogger(r.Context(), s.logger) var sessionId, protocolVersion string var session *sseSession @@ -379,7 +461,6 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) span.SetAttributes(attribute.String("toolset_name", toolsetName)) - var err error defer func() { if err != nil { span.SetStatus(codes.Error, err.Error()) @@ -398,17 +479,9 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { ) }() - // Read and returns a body from io.Reader - body, err := io.ReadAll(r.Body) - if err != nil { - // Generate a new uuid if unable to decode - id := uuid.New().String() - s.logger.DebugContext(ctx, err.Error()) - render.JSON(w, r, jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil)) - return - } + networkProtocolVersion := fmt.Sprintf("%d.%d", r.ProtoMajor, r.ProtoMinor) - v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, promptsetName, r.Header) + v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, promptsetName, r.Header, networkProtocolVersion) if err != nil { s.logger.DebugContext(ctx, fmt.Errorf("error processing message: %w", err).Error()) } @@ -458,7 +531,7 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { } // processMcpMessage process the messages received from clients -func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string, promptsetName string, header http.Header) (string, any, error) { +func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string, promptsetName string, header http.Header, networkProtocolVersion string) (string, any, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return "", jsonrpc.NewError("", jsonrpc.INTERNAL_ERROR, err.Error(), nil), err @@ -494,31 +567,95 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } + // Create method-specific span with semantic conventions + // Note: Trace context is already extracted and set in ctx by the caller + ctx, span := s.instrumentation.Tracer.Start(ctx, baseMessage.Method, + trace.WithSpanKind(trace.SpanKindServer), + ) + defer span.End() + + // Determine network transport and protocol based on header presence + networkTransport := "pipe" // default for stdio + networkProtocolName := "stdio" + if header != nil { + networkTransport = "tcp" // HTTP/SSE transport + networkProtocolName = "http" + } + + // Set required semantic attributes for span according to OTEL MCP semcov + // ref: https://opentelemetry.io/docs/specs/semconv/gen-ai/mcp/#server + span.SetAttributes( + attribute.String("mcp.method.name", baseMessage.Method), + attribute.String("network.transport", networkTransport), + attribute.String("network.protocol.name", networkProtocolName), + ) + + // Set network protocol version if available + if networkProtocolVersion != "" { + span.SetAttributes(attribute.String("network.protocol.version", networkProtocolVersion)) + } + + // Set MCP protocol version if available + if protocolVersion != "" { + span.SetAttributes(attribute.String("mcp.protocol.version", protocolVersion)) + } + + // Set request ID + if baseMessage.Id != nil { + span.SetAttributes(attribute.String("jsonrpc.request.id", fmt.Sprintf("%v", baseMessage.Id))) + } + + // Set toolset name + span.SetAttributes(attribute.String("toolset.name", toolsetName)) + // Check if message is a notification if baseMessage.Id == nil { err := mcp.NotificationHandler(ctx, body) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + } return "", nil, err } + // Process the method switch baseMessage.Method { case mcputil.INITIALIZE: - res, v, err := mcp.InitializeResponse(ctx, baseMessage.Id, body, s.version) + result, version, err := mcp.InitializeResponse(ctx, baseMessage.Id, body, s.version) if err != nil { - return "", res, err + span.SetStatus(codes.Error, err.Error()) + if rpcErr, ok := result.(jsonrpc.JSONRPCError); ok { + span.SetAttributes(attribute.String("error.type", rpcErr.Error.String())) + } + return "", result, err } - return v, res, err + span.SetAttributes(attribute.String("mcp.protocol.version", version)) + return version, result, err default: toolset, ok := s.ResourceMgr.GetToolset(toolsetName) if !ok { - err = fmt.Errorf("toolset does not exist") - return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err + err := fmt.Errorf("toolset does not exist") + rpcErr := jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil) + span.SetStatus(codes.Error, err.Error()) + span.SetAttributes(attribute.String("error.type", rpcErr.Error.String())) + return "", rpcErr, err } promptset, ok := s.ResourceMgr.GetPromptset(promptsetName) if !ok { - err = fmt.Errorf("promptset does not exist") - return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err + err := fmt.Errorf("promptset does not exist") + rpcErr := jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil) + span.SetStatus(codes.Error, err.Error()) + span.SetAttributes(attribute.String("error.type", rpcErr.Error.String())) + return "", rpcErr, err } - res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, promptset, s.ResourceMgr, body, header) - return "", res, err + result, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, promptset, s.ResourceMgr, body, header) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + // Set error.type based on JSON-RPC error code + if rpcErr, ok := result.(jsonrpc.JSONRPCError); ok { + span.SetAttributes(attribute.Int("jsonrpc.error.code", rpcErr.Error.Code)) + span.SetAttributes(attribute.String("error.type", rpcErr.Error.String())) + } + } + return "", result, err } } diff --git a/internal/server/mcp/jsonrpc/jsonrpc.go b/internal/server/mcp/jsonrpc/jsonrpc.go index 7099ea8a63..8a4aaaf15b 100644 --- a/internal/server/mcp/jsonrpc/jsonrpc.go +++ b/internal/server/mcp/jsonrpc/jsonrpc.go @@ -45,6 +45,9 @@ type Request struct { // notifications. The receiver is not obligated to provide these // notifications. ProgressToken ProgressToken `json:"progressToken,omitempty"` + // W3C Trace Context fields for distributed tracing + Traceparent string `json:"traceparent,omitempty"` + Tracestate string `json:"tracestate,omitempty"` } `json:"_meta,omitempty"` } `json:"params,omitempty"` } @@ -97,6 +100,24 @@ type Error struct { Data interface{} `json:"data,omitempty"` } +// String returns the error type as a string based on the error code. +func (e Error) String() string { + switch e.Code { + case METHOD_NOT_FOUND: + return "method_not_found" + case INVALID_PARAMS: + return "invalid_params" + case INTERNAL_ERROR: + return "internal_error" + case PARSE_ERROR: + return "parse_error" + case INVALID_REQUEST: + return "invalid_request" + default: + return "jsonrpc_error" + } +} + // JSONRPCError represents a non-successful (error) response to a request. type JSONRPCError struct { Jsonrpc string `json:"jsonrpc"` diff --git a/internal/server/mcp/v20241105/method.go b/internal/server/mcp/v20241105/method.go index 0dd6943734..4684f4687c 100644 --- a/internal/server/mcp/v20241105/method.go +++ b/internal/server/mcp/v20241105/method.go @@ -28,6 +28,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // ProcessMethod returns a response for the request. @@ -101,6 +103,14 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re toolName := req.Params.Name toolArgument := req.Params.Arguments logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) + + // Update span name and set gen_ai attributes + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName)) + span.SetAttributes( + attribute.String("gen_ai.tool.name", toolName), + attribute.String("gen_ai.operation.name", "execute_tool"), + ) tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -310,6 +320,11 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r promptName := req.Params.Name logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName)) + + // Update span name and set gen_ai attributes + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) + span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/server/mcp/v20250326/method.go b/internal/server/mcp/v20250326/method.go index 22183d45d9..24c61fd617 100644 --- a/internal/server/mcp/v20250326/method.go +++ b/internal/server/mcp/v20250326/method.go @@ -28,6 +28,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // ProcessMethod returns a response for the request. @@ -101,6 +103,15 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re toolName := req.Params.Name toolArgument := req.Params.Arguments logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) + + // Update span name and set gen_ai attributes + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName)) + span.SetAttributes( + attribute.String("gen_ai.tool.name", toolName), + attribute.String("gen_ai.operation.name", "execute_tool"), + ) + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -309,6 +320,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r promptName := req.Params.Name logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName)) + + // Update span name and set gen_ai attributes + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) + span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/server/mcp/v20250618/method.go b/internal/server/mcp/v20250618/method.go index 24312d2da9..b6cb45059b 100644 --- a/internal/server/mcp/v20250618/method.go +++ b/internal/server/mcp/v20250618/method.go @@ -28,6 +28,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // ProcessMethod returns a response for the request. @@ -94,6 +96,15 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re toolName := req.Params.Name toolArgument := req.Params.Arguments logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) + + // Update span name and set gen_ai attributes + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName)) + span.SetAttributes( + attribute.String("gen_ai.tool.name", toolName), + attribute.String("gen_ai.operation.name", "execute_tool"), + ) + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -303,6 +314,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r promptName := req.Params.Name logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName)) + + // Update span name and set gen_ai attributes + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) + span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/server/mcp/v20251125/method.go b/internal/server/mcp/v20251125/method.go index 408fd0303c..2d59554c55 100644 --- a/internal/server/mcp/v20251125/method.go +++ b/internal/server/mcp/v20251125/method.go @@ -28,6 +28,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // ProcessMethod returns a response for the request. @@ -94,6 +96,15 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re toolName := req.Params.Name toolArgument := req.Params.Arguments logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) + + // Update span name and set gen_ai attributes + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName)) + span.SetAttributes( + attribute.String("gen_ai.tool.name", toolName), + attribute.String("gen_ai.operation.name", "execute_tool"), + ) + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -303,6 +314,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r promptName := req.Params.Name logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName)) + + // Update span name and set gen_ai attributes + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) + span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) From e1739abd811c85cd70198c687b254c59aa29c0f7 Mon Sep 17 00:00:00 2001 From: Yuan Teoh <45984206+Yuan325@users.noreply.github.com> Date: Thu, 12 Feb 2026 17:16:23 -0800 Subject: [PATCH 7/9] chore: release 0.27.0 (#2467) Release-As: 0.27.0 --- .hugo/hugo.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.hugo/hugo.toml b/.hugo/hugo.toml index e3fb75803c..76e253dc0a 100644 --- a/.hugo/hugo.toml +++ b/.hugo/hugo.toml @@ -51,6 +51,10 @@ ignoreFiles = ["quickstart/shared", "quickstart/python", "quickstart/js", "quick # Add a new version block here before every release # The order of versions in this file is mirrored into the dropdown +[[params.versions]] + version = "v0.27.0" + url = "https://googleapis.github.io/genai-toolbox/v0.27.0/" + [[params.versions]] version = "v0.26.0" url = "https://googleapis.github.io/genai-toolbox/v0.26.0/" From c5524d32f580fed81c8b90448e2f17e719710ff9 Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Thu, 12 Feb 2026 18:03:05 -0800 Subject: [PATCH 8/9] chore(main): release 0.27.0 (#2363) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit :robot: I have created a release *beep* *boop* --- ## [0.27.0](https://github.com/googleapis/genai-toolbox/compare/v0.26.0...v0.27.0) (2026-02-12) ### ⚠ BREAKING CHANGES * Update configuration file v2 ([#2369](https://github.com/googleapis/genai-toolbox/issues/2369))([293c1d6](https://github.com/googleapis/genai-toolbox/commit/293c1d6889c39807855ba5e01d4c13ba2a4c50ce)) * Update/add detailed telemetry for mcp endpoint compliant with OTEL semantic convention ([#1987](https://github.com/googleapis/genai-toolbox/issues/1987)) ([478a0bd](https://github.com/googleapis/genai-toolbox/commit/478a0bdb59288c1213f83862f95a698b4c2c0aab)) ### Features * **cli/invoke:** Add support for direct tool invocation from CLI ([#2353](https://github.com/googleapis/genai-toolbox/issues/2353)) ([6e49ba4](https://github.com/googleapis/genai-toolbox/commit/6e49ba436ef2390c13feaf902b29f5907acffb57)) * **cli/skills:** Add support for generating agent skills from toolset ([#2392](https://github.com/googleapis/genai-toolbox/issues/2392)) ([80ef346](https://github.com/googleapis/genai-toolbox/commit/80ef34621453b77bdf6a6016c354f102a17ada04)) * **cloud-logging-admin:** Add source, tools, integration test and docs ([#2137](https://github.com/googleapis/genai-toolbox/issues/2137)) ([252fc30](https://github.com/googleapis/genai-toolbox/commit/252fc3091af10d25d8d7af7e047b5ac87a5dd041)) * **cockroachdb:** Add CockroachDB integration with cockroach-go ([#2006](https://github.com/googleapis/genai-toolbox/issues/2006)) ([1fdd99a](https://github.com/googleapis/genai-toolbox/commit/1fdd99a9b609a5e906acce414226ff44d75d5975)) * **prebuiltconfigs/alloydb-omni:** Implement Alloydb omni dataplane tools ([#2340](https://github.com/googleapis/genai-toolbox/issues/2340)) ([e995349](https://github.com/googleapis/genai-toolbox/commit/e995349ea0756c700d188b8f04e9459121219f0c)) * **server:** Add Tool call error categories ([#2387](https://github.com/googleapis/genai-toolbox/issues/2387)) ([32cb4db](https://github.com/googleapis/genai-toolbox/commit/32cb4db712d27579c1bf29e61cbd0bed02286c28)) * **tools/looker:** support `looker-validate-project` tool ([#2430](https://github.com/googleapis/genai-toolbox/issues/2430)) ([a15a128](https://github.com/googleapis/genai-toolbox/commit/a15a12873f936b0102aeb9500cc3bcd71bb38c34)) ### Bug Fixes * **dataplex:** Capture GCP HTTP errors in MCP Toolbox ([#2347](https://github.com/googleapis/genai-toolbox/issues/2347)) ([1d7c498](https://github.com/googleapis/genai-toolbox/commit/1d7c4981164c34b4d7bc8edecfd449f57ad11e15)) * **sources/cockroachdb:** Update kind to type ([#2465](https://github.com/googleapis/genai-toolbox/issues/2465)) ([2d341ac](https://github.com/googleapis/genai-toolbox/commit/2d341acaa61c3c1fe908fceee8afbd90fb646d3a)) * Surface Dataplex API errors in MCP results ([#2347](https://github.com/googleapis/genai-toolbox/pull/2347))([1d7c498](https://github.com/googleapis/genai-toolbox/commit/1d7c4981164c34b4d7bc8edecfd449f57ad11e15)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). --------- Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com> --- CHANGELOG.md | 26 +++++++++++++++++++ README.md | 14 +++++----- cmd/version.txt | 2 +- .../en/getting-started/colab_quickstart.ipynb | 2 +- .../en/getting-started/introduction/_index.md | 14 +++++----- .../getting-started/mcp_quickstart/_index.md | 2 +- .../quickstart/shared/configure_toolbox.md | 2 +- docs/en/how-to/connect-ide/looker_mcp.md | 8 +++--- docs/en/how-to/connect-ide/mssql_mcp.md | 8 +++--- docs/en/how-to/connect-ide/mysql_mcp.md | 8 +++--- docs/en/how-to/connect-ide/neo4j_mcp.md | 8 +++--- docs/en/how-to/connect-ide/postgres_mcp.md | 8 +++--- docs/en/how-to/connect-ide/sqlite_mcp.md | 8 +++--- .../samples/alloydb/ai-nl/alloydb_ai_nl.ipynb | 2 +- docs/en/samples/alloydb/mcp_quickstart.md | 2 +- .../bigquery/colab_quickstart_bigquery.ipynb | 2 +- docs/en/samples/bigquery/local_quickstart.md | 2 +- .../samples/bigquery/mcp_quickstart/_index.md | 2 +- docs/en/samples/looker/looker_gemini.md | 2 +- .../looker/looker_gemini_oauth/_index.md | 2 +- .../looker/looker_mcp_inspector/_index.md | 2 +- gemini-extension.json | 2 +- server.json | 4 +-- 23 files changed, 79 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d18812102..dc9077015d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,31 @@ # Changelog +## [0.27.0](https://github.com/googleapis/genai-toolbox/compare/v0.26.0...v0.27.0) (2026-02-12) + + +### ⚠ BREAKING CHANGES + +* Update configuration file v2 ([#2369](https://github.com/googleapis/genai-toolbox/issues/2369))([293c1d6](https://github.com/googleapis/genai-toolbox/commit/293c1d6889c39807855ba5e01d4c13ba2a4c50ce)) +* Update/add detailed telemetry for mcp endpoint compliant with OTEL semantic convention ([#1987](https://github.com/googleapis/genai-toolbox/issues/1987)) ([478a0bd](https://github.com/googleapis/genai-toolbox/commit/478a0bdb59288c1213f83862f95a698b4c2c0aab)) + +### Features + +* **cli/invoke:** Add support for direct tool invocation from CLI ([#2353](https://github.com/googleapis/genai-toolbox/issues/2353)) ([6e49ba4](https://github.com/googleapis/genai-toolbox/commit/6e49ba436ef2390c13feaf902b29f5907acffb57)) +* **cli/skills:** Add support for generating agent skills from toolset ([#2392](https://github.com/googleapis/genai-toolbox/issues/2392)) ([80ef346](https://github.com/googleapis/genai-toolbox/commit/80ef34621453b77bdf6a6016c354f102a17ada04)) +* **cloud-logging-admin:** Add source, tools, integration test and docs ([#2137](https://github.com/googleapis/genai-toolbox/issues/2137)) ([252fc30](https://github.com/googleapis/genai-toolbox/commit/252fc3091af10d25d8d7af7e047b5ac87a5dd041)) +* **cockroachdb:** Add CockroachDB integration with cockroach-go ([#2006](https://github.com/googleapis/genai-toolbox/issues/2006)) ([1fdd99a](https://github.com/googleapis/genai-toolbox/commit/1fdd99a9b609a5e906acce414226ff44d75d5975)) +* **prebuiltconfigs/alloydb-omni:** Implement Alloydb omni dataplane tools ([#2340](https://github.com/googleapis/genai-toolbox/issues/2340)) ([e995349](https://github.com/googleapis/genai-toolbox/commit/e995349ea0756c700d188b8f04e9459121219f0c)) +* **server:** Add Tool call error categories ([#2387](https://github.com/googleapis/genai-toolbox/issues/2387)) ([32cb4db](https://github.com/googleapis/genai-toolbox/commit/32cb4db712d27579c1bf29e61cbd0bed02286c28)) +* **tools/looker:** support `looker-validate-project` tool ([#2430](https://github.com/googleapis/genai-toolbox/issues/2430)) ([a15a128](https://github.com/googleapis/genai-toolbox/commit/a15a12873f936b0102aeb9500cc3bcd71bb38c34)) + + + +### Bug Fixes + +* **dataplex:** Capture GCP HTTP errors in MCP Toolbox ([#2347](https://github.com/googleapis/genai-toolbox/issues/2347)) ([1d7c498](https://github.com/googleapis/genai-toolbox/commit/1d7c4981164c34b4d7bc8edecfd449f57ad11e15)) +* **sources/cockroachdb:** Update kind to type ([#2465](https://github.com/googleapis/genai-toolbox/issues/2465)) ([2d341ac](https://github.com/googleapis/genai-toolbox/commit/2d341acaa61c3c1fe908fceee8afbd90fb646d3a)) +* Surface Dataplex API errors in MCP results ([#2347](https://github.com/googleapis/genai-toolbox/pull/2347))([1d7c498](https://github.com/googleapis/genai-toolbox/commit/1d7c4981164c34b4d7bc8edecfd449f57ad11e15)) + ## [0.26.0](https://github.com/googleapis/genai-toolbox/compare/v0.25.0...v0.26.0) (2026-01-22) diff --git a/README.md b/README.md index ae6f0dd949..160b9456e2 100644 --- a/README.md +++ b/README.md @@ -142,7 +142,7 @@ To install Toolbox as a binary: > > ```sh > # see releases page for other versions -> export VERSION=0.26.0 +> export VERSION=0.27.0 > curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox > chmod +x toolbox > ``` @@ -155,7 +155,7 @@ To install Toolbox as a binary: > > ```sh > # see releases page for other versions -> export VERSION=0.26.0 +> export VERSION=0.27.0 > curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox > chmod +x toolbox > ``` @@ -168,7 +168,7 @@ To install Toolbox as a binary: > > ```sh > # see releases page for other versions -> export VERSION=0.26.0 +> export VERSION=0.27.0 > curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox > chmod +x toolbox > ``` @@ -181,7 +181,7 @@ To install Toolbox as a binary: > > ```cmd > :: see releases page for other versions -> set VERSION=0.26.0 +> set VERSION=0.27.0 > curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe" > ``` > @@ -193,7 +193,7 @@ To install Toolbox as a binary: > > ```powershell > # see releases page for other versions -> $VERSION = "0.26.0" +> $VERSION = "0.27.0" > curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe" > ``` > @@ -206,7 +206,7 @@ You can also install Toolbox as a container: ```sh # see releases page for other versions -export VERSION=0.26.0 +export VERSION=0.27.0 docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION ``` @@ -230,7 +230,7 @@ To install from source, ensure you have the latest version of [Go installed](https://go.dev/doc/install), and then run the following command: ```sh -go install github.com/googleapis/genai-toolbox@v0.26.0 +go install github.com/googleapis/genai-toolbox@v0.27.0 ``` diff --git a/cmd/version.txt b/cmd/version.txt index 4e8f395fa5..1b58cc1018 100644 --- a/cmd/version.txt +++ b/cmd/version.txt @@ -1 +1 @@ -0.26.0 +0.27.0 diff --git a/docs/en/getting-started/colab_quickstart.ipynb b/docs/en/getting-started/colab_quickstart.ipynb index 4b63090c3a..a12429180d 100644 --- a/docs/en/getting-started/colab_quickstart.ipynb +++ b/docs/en/getting-started/colab_quickstart.ipynb @@ -234,7 +234,7 @@ }, "outputs": [], "source": [ - "version = \"0.26.0\" # x-release-please-version\n", + "version = \"0.27.0\" # x-release-please-version\n", "! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n", "\n", "# Make the binary executable\n", diff --git a/docs/en/getting-started/introduction/_index.md b/docs/en/getting-started/introduction/_index.md index 93c6195bef..65453b7030 100644 --- a/docs/en/getting-started/introduction/_index.md +++ b/docs/en/getting-started/introduction/_index.md @@ -109,7 +109,7 @@ To install Toolbox as a binary on Linux (AMD64): ```sh # see releases page for other versions -export VERSION=0.26.0 +export VERSION=0.27.0 curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox chmod +x toolbox ``` @@ -120,7 +120,7 @@ To install Toolbox as a binary on macOS (Apple Silicon): ```sh # see releases page for other versions -export VERSION=0.26.0 +export VERSION=0.27.0 curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox chmod +x toolbox ``` @@ -131,7 +131,7 @@ To install Toolbox as a binary on macOS (Intel): ```sh # see releases page for other versions -export VERSION=0.26.0 +export VERSION=0.27.0 curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox chmod +x toolbox ``` @@ -142,7 +142,7 @@ To install Toolbox as a binary on Windows (Command Prompt): ```cmd :: see releases page for other versions -set VERSION=0.26.0 +set VERSION=0.27.0 curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe" ``` @@ -152,7 +152,7 @@ To install Toolbox as a binary on Windows (PowerShell): ```powershell # see releases page for other versions -$VERSION = "0.26.0" +$VERSION = "0.27.0" curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe" ``` @@ -164,7 +164,7 @@ You can also install Toolbox as a container: ```sh # see releases page for other versions -export VERSION=0.26.0 +export VERSION=0.27.0 docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION ``` @@ -183,7 +183,7 @@ To install from source, ensure you have the latest version of [Go installed](https://go.dev/doc/install), and then run the following command: ```sh -go install github.com/googleapis/genai-toolbox@v0.26.0 +go install github.com/googleapis/genai-toolbox@v0.27.0 ``` {{% /tab %}} diff --git a/docs/en/getting-started/mcp_quickstart/_index.md b/docs/en/getting-started/mcp_quickstart/_index.md index 0bc7b94733..b005643a10 100644 --- a/docs/en/getting-started/mcp_quickstart/_index.md +++ b/docs/en/getting-started/mcp_quickstart/_index.md @@ -105,7 +105,7 @@ In this section, we will download Toolbox, configure our tools in a ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox ``` diff --git a/docs/en/getting-started/quickstart/shared/configure_toolbox.md b/docs/en/getting-started/quickstart/shared/configure_toolbox.md index bea2ed4d60..3a20508982 100644 --- a/docs/en/getting-started/quickstart/shared/configure_toolbox.md +++ b/docs/en/getting-started/quickstart/shared/configure_toolbox.md @@ -13,7 +13,7 @@ In this section, we will download Toolbox, configure our tools in a ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox ``` diff --git a/docs/en/how-to/connect-ide/looker_mcp.md b/docs/en/how-to/connect-ide/looker_mcp.md index 2972ddb341..ed8ea16a34 100644 --- a/docs/en/how-to/connect-ide/looker_mcp.md +++ b/docs/en/how-to/connect-ide/looker_mcp.md @@ -100,19 +100,19 @@ After you install Looker in the MCP Store, resources and tools from the server a {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/mssql_mcp.md b/docs/en/how-to/connect-ide/mssql_mcp.md index 722feeaebe..5a11f9c2e4 100644 --- a/docs/en/how-to/connect-ide/mssql_mcp.md +++ b/docs/en/how-to/connect-ide/mssql_mcp.md @@ -45,19 +45,19 @@ instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/mysql_mcp.md b/docs/en/how-to/connect-ide/mysql_mcp.md index de1c15d839..c101a597ee 100644 --- a/docs/en/how-to/connect-ide/mysql_mcp.md +++ b/docs/en/how-to/connect-ide/mysql_mcp.md @@ -43,19 +43,19 @@ expose your developer assistant tools to a MySQL instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/neo4j_mcp.md b/docs/en/how-to/connect-ide/neo4j_mcp.md index 624d7540aa..2f68c1d2de 100644 --- a/docs/en/how-to/connect-ide/neo4j_mcp.md +++ b/docs/en/how-to/connect-ide/neo4j_mcp.md @@ -44,19 +44,19 @@ expose your developer assistant tools to a Neo4j instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/postgres_mcp.md b/docs/en/how-to/connect-ide/postgres_mcp.md index 7d778c6a4d..e93e37492c 100644 --- a/docs/en/how-to/connect-ide/postgres_mcp.md +++ b/docs/en/how-to/connect-ide/postgres_mcp.md @@ -56,19 +56,19 @@ Omni](https://cloud.google.com/alloydb/omni/docs/overview). {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/sqlite_mcp.md b/docs/en/how-to/connect-ide/sqlite_mcp.md index 8f0bdf4dac..65d0c39ef0 100644 --- a/docs/en/how-to/connect-ide/sqlite_mcp.md +++ b/docs/en/how-to/connect-ide/sqlite_mcp.md @@ -43,19 +43,19 @@ to expose your developer assistant tools to a SQLite instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/samples/alloydb/ai-nl/alloydb_ai_nl.ipynb b/docs/en/samples/alloydb/ai-nl/alloydb_ai_nl.ipynb index d5c0b725a4..aba3c2a7d2 100644 --- a/docs/en/samples/alloydb/ai-nl/alloydb_ai_nl.ipynb +++ b/docs/en/samples/alloydb/ai-nl/alloydb_ai_nl.ipynb @@ -771,7 +771,7 @@ }, "outputs": [], "source": [ - "version = \"0.26.0\" # x-release-please-version\n", + "version = \"0.27.0\" # x-release-please-version\n", "! curl -L -o /content/toolbox https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n", "\n", "# Make the binary executable\n", diff --git a/docs/en/samples/alloydb/mcp_quickstart.md b/docs/en/samples/alloydb/mcp_quickstart.md index 07034a7f33..43ae7fd311 100644 --- a/docs/en/samples/alloydb/mcp_quickstart.md +++ b/docs/en/samples/alloydb/mcp_quickstart.md @@ -123,7 +123,7 @@ In this section, we will download and install the Toolbox binary. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - export VERSION="0.26.0" + export VERSION="0.27.0" curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/$OS/toolbox ``` diff --git a/docs/en/samples/bigquery/colab_quickstart_bigquery.ipynb b/docs/en/samples/bigquery/colab_quickstart_bigquery.ipynb index d4da45edb0..6dbdc66e57 100644 --- a/docs/en/samples/bigquery/colab_quickstart_bigquery.ipynb +++ b/docs/en/samples/bigquery/colab_quickstart_bigquery.ipynb @@ -220,7 +220,7 @@ }, "outputs": [], "source": [ - "version = \"0.26.0\" # x-release-please-version\n", + "version = \"0.27.0\" # x-release-please-version\n", "! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n", "\n", "# Make the binary executable\n", diff --git a/docs/en/samples/bigquery/local_quickstart.md b/docs/en/samples/bigquery/local_quickstart.md index 8f98b588a2..e8e64cd9ee 100644 --- a/docs/en/samples/bigquery/local_quickstart.md +++ b/docs/en/samples/bigquery/local_quickstart.md @@ -179,7 +179,7 @@ to use BigQuery, and then run the Toolbox server. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox ``` diff --git a/docs/en/samples/bigquery/mcp_quickstart/_index.md b/docs/en/samples/bigquery/mcp_quickstart/_index.md index 1216daac7e..3ca10183b1 100644 --- a/docs/en/samples/bigquery/mcp_quickstart/_index.md +++ b/docs/en/samples/bigquery/mcp_quickstart/_index.md @@ -98,7 +98,7 @@ In this section, we will download Toolbox, configure our tools in a ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox ``` diff --git a/docs/en/samples/looker/looker_gemini.md b/docs/en/samples/looker/looker_gemini.md index 33ed1fe580..70bde9465a 100644 --- a/docs/en/samples/looker/looker_gemini.md +++ b/docs/en/samples/looker/looker_gemini.md @@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox ``` diff --git a/docs/en/samples/looker/looker_gemini_oauth/_index.md b/docs/en/samples/looker/looker_gemini_oauth/_index.md index b9e224a1a1..5e7b574c55 100644 --- a/docs/en/samples/looker/looker_gemini_oauth/_index.md +++ b/docs/en/samples/looker/looker_gemini_oauth/_index.md @@ -48,7 +48,7 @@ In this section, we will download Toolbox and run the Toolbox server. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox ``` diff --git a/docs/en/samples/looker/looker_mcp_inspector/_index.md b/docs/en/samples/looker/looker_mcp_inspector/_index.md index ca0de51f99..56338b3576 100644 --- a/docs/en/samples/looker/looker_mcp_inspector/_index.md +++ b/docs/en/samples/looker/looker_mcp_inspector/_index.md @@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox ``` diff --git a/gemini-extension.json b/gemini-extension.json index 7e6a846b15..4a5c0e0b8d 100644 --- a/gemini-extension.json +++ b/gemini-extension.json @@ -1,6 +1,6 @@ { "name": "mcp-toolbox-for-databases", - "version": "0.26.0", + "version": "0.27.0", "description": "MCP Toolbox for Databases is an open-source MCP server for more than 30 different datasources.", "contextFileName": "MCP-TOOLBOX-EXTENSION.md" } \ No newline at end of file diff --git a/server.json b/server.json index 542bbffa1f..f87582cd90 100644 --- a/server.json +++ b/server.json @@ -14,11 +14,11 @@ "url": "https://github.com/googleapis/genai-toolbox", "source": "github" }, - "version": "0.26.0", + "version": "0.27.0", "packages": [ { "registryType": "oci", - "identifier": "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:0.26.0", + "identifier": "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:0.27.0", "transport": { "type": "streamable-http", "url": "http://{host}:{port}/mcp" From 195767bdcda88a23bb9983b6541de41ccb434ce2 Mon Sep 17 00:00:00 2001 From: Yuan Teoh <45984206+Yuan325@users.noreply.github.com> Date: Thu, 12 Feb 2026 18:28:58 -0800 Subject: [PATCH 9/9] refactor: refactor subcommands and move tests to its own package (#2439) This PR refactors the command structure to decouple subcommands from the root command, improving modularity and testability. **Key Changes:** - Move `internal/cli` to `cmd/internal`. Being in a `internal` folder, other packages outside of `cmd` will not be able to import them. - Encapsulated I/O: Introduced a new IOStreams struct to standardize in, out, and errOut handling. - Shared Dependencies: Extracted shared fields (including IOStreams, Logger, ServerConfig, and various Tools paths) from the root `Command` into a new `ToolboxOptions` struct. This also includes moving `cmd/options.go` to be part of `ToolboxOptions`. - Logic Migration: Moved setup logic, such as `Setup()` and `LoadConfig()`, into `ToolboxOptions`. Removing the need to import `rootCmd` to subcommands. - Package Reorganization: - Relocated PersistentFlag and ToolsFiles to the cli package to remove base command dependencies. This removes dependencies on the base command, allowing subcommands to consume these utilities independently. - Moved all side-effect registration to the `cmd/internal` package, enabling other packages to import it safely for unit tests. **Testing Improvements:** - Subcommand packages can now be tested in isolation without relying on the base command package. - Added `TestSubcommandWiring()` to the base command tests to verify proper subcommand registration. --- cmd/internal/imports.go | 257 ++ .../cli => cmd/internal}/invoke/command.go | 58 +- .../invoke/command_test.go} | 30 +- cmd/internal/options.go | 251 ++ cmd/{ => internal}/options_test.go | 33 +- cmd/internal/persistent_flags.go | 46 + .../cli => cmd/internal}/skills/command.go | 85 +- .../skills/command_test.go} | 34 +- .../cli => cmd/internal}/skills/generator.go | 0 .../internal}/skills/generator_test.go | 0 cmd/internal/tools_file.go | 349 +++ cmd/internal/tools_file_test.go | 2141 ++++++++++++++++ cmd/options.go | 30 - cmd/root.go | 891 +------ cmd/root_test.go | 2204 +---------------- tests/server.go | 5 +- 16 files changed, 3274 insertions(+), 3140 deletions(-) create mode 100644 cmd/internal/imports.go rename {internal/cli => cmd/internal}/invoke/command.go (65%) rename cmd/{invoke_tool_test.go => internal/invoke/command_test.go} (80%) create mode 100644 cmd/internal/options.go rename cmd/{ => internal}/options_test.go (62%) create mode 100644 cmd/internal/persistent_flags.go rename {internal/cli => cmd/internal}/skills/command.go (67%) rename cmd/{skill_generate_test.go => internal/skills/command_test.go} (87%) rename {internal/cli => cmd/internal}/skills/generator.go (100%) rename {internal/cli => cmd/internal}/skills/generator_test.go (100%) create mode 100644 cmd/internal/tools_file.go create mode 100644 cmd/internal/tools_file_test.go delete mode 100644 cmd/options.go diff --git a/cmd/internal/imports.go b/cmd/internal/imports.go new file mode 100644 index 0000000000..24c58d7b85 --- /dev/null +++ b/cmd/internal/imports.go @@ -0,0 +1,257 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + // Import prompt packages for side effect of registration + _ "github.com/googleapis/genai-toolbox/internal/prompts/custom" + + // Import tool packages for side effect of registration + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreatecluster" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreateinstance" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreateuser" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetcluster" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetinstance" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetuser" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistclusters" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistinstances" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistusers" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbwaitforoperation" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryanalyzecontribution" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryconversationalanalytics" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygetdatasetinfo" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygettableinfo" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylistdatasetids" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylisttableids" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysearchcatalog" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigtable" + _ "github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql" + _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases" + _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdataset" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminlistlognames" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminlistresourcetypes" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminquerylogs" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistdatabases" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistinstances" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck" + _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblistschemas" + _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbsql" + _ "github.com/googleapis/genai-toolbox/internal/tools/couchbase" + _ "github.com/googleapis/genai-toolbox/internal/tools/dataform/dataformcompilelocal" + _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupentry" + _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchaspecttypes" + _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchentries" + _ "github.com/googleapis/genai-toolbox/internal/tools/dgraph" + _ "github.com/googleapis/genai-toolbox/internal/tools/elasticsearch/elasticsearchesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/firebird/firebirdexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/firebird/firebirdsql" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreadddocuments" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoredeletedocuments" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetdocuments" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetrules" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorelistcollections" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequery" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequerycollection" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreupdatedocument" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules" + _ "github.com/googleapis/genai-toolbox/internal/tools/http" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardfilter" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectfile" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdevmode" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergenerateembedurl" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiondatabases" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnections" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectionschemas" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontablecolumns" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdashboards" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdimensions" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetexplores" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetfilters" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetlooks" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmeasures" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmodels" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetparameters" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojectfile" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojectfiles" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojects" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthanalyze" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthpulse" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthvacuum" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookermakedashboard" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookermakelook" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquery" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquerysql" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerqueryurl" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrundashboard" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrunlook" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerupdateprojectfile" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookervalidateproject" + _ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbsql" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbaggregate" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeletemany" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeleteone" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfind" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfindone" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertmany" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertone" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdatemany" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdateone" + _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqllisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllistactivequeries" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablefragmentation" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablesmissinguniqueindexes" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlsql" + _ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jcypher" + _ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher" + _ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema" + _ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbaseexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbasesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oracleexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresdatabaseoverview" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresgetcolumncardinality" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistactivequeries" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistavailableextensions" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistdatabasestats" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistindexes" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistinstalledextensions" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistlocks" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpgsettings" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpublicationtables" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistquerystats" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistroles" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistschemas" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistsequences" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresliststoredprocedure" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablespaces" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablestats" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttriggers" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistviews" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslongrunningtransactions" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresreplicationstats" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" + _ "github.com/googleapis/genai-toolbox/internal/tools/redis" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" + _ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoreexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoresql" + _ "github.com/googleapis/genai-toolbox/internal/tools/snowflake/snowflakeexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/snowflake/snowflakesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlistgraphs" + _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannersql" + _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbsql" + _ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinoexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinosql" + _ "github.com/googleapis/genai-toolbox/internal/tools/utility/wait" + _ "github.com/googleapis/genai-toolbox/internal/tools/valkey" + _ "github.com/googleapis/genai-toolbox/internal/tools/yugabytedbsql" + + _ "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" + _ "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" + _ "github.com/googleapis/genai-toolbox/internal/sources/bigquery" + _ "github.com/googleapis/genai-toolbox/internal/sources/bigtable" + _ "github.com/googleapis/genai-toolbox/internal/sources/cassandra" + _ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudloggingadmin" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" + _ "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" + _ "github.com/googleapis/genai-toolbox/internal/sources/couchbase" + _ "github.com/googleapis/genai-toolbox/internal/sources/dataplex" + _ "github.com/googleapis/genai-toolbox/internal/sources/dgraph" + _ "github.com/googleapis/genai-toolbox/internal/sources/elasticsearch" + _ "github.com/googleapis/genai-toolbox/internal/sources/firebird" + _ "github.com/googleapis/genai-toolbox/internal/sources/firestore" + _ "github.com/googleapis/genai-toolbox/internal/sources/http" + _ "github.com/googleapis/genai-toolbox/internal/sources/looker" + _ "github.com/googleapis/genai-toolbox/internal/sources/mindsdb" + _ "github.com/googleapis/genai-toolbox/internal/sources/mongodb" + _ "github.com/googleapis/genai-toolbox/internal/sources/mssql" + _ "github.com/googleapis/genai-toolbox/internal/sources/mysql" + _ "github.com/googleapis/genai-toolbox/internal/sources/neo4j" + _ "github.com/googleapis/genai-toolbox/internal/sources/oceanbase" + _ "github.com/googleapis/genai-toolbox/internal/sources/oracle" + _ "github.com/googleapis/genai-toolbox/internal/sources/postgres" + _ "github.com/googleapis/genai-toolbox/internal/sources/redis" + _ "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" + _ "github.com/googleapis/genai-toolbox/internal/sources/singlestore" + _ "github.com/googleapis/genai-toolbox/internal/sources/snowflake" + _ "github.com/googleapis/genai-toolbox/internal/sources/spanner" + _ "github.com/googleapis/genai-toolbox/internal/sources/sqlite" + _ "github.com/googleapis/genai-toolbox/internal/sources/tidb" + _ "github.com/googleapis/genai-toolbox/internal/sources/trino" + _ "github.com/googleapis/genai-toolbox/internal/sources/valkey" + _ "github.com/googleapis/genai-toolbox/internal/sources/yugabytedb" +) diff --git a/internal/cli/invoke/command.go b/cmd/internal/invoke/command.go similarity index 65% rename from internal/cli/invoke/command.go rename to cmd/internal/invoke/command.go index 22ab8e55d3..81837402f1 100644 --- a/internal/cli/invoke/command.go +++ b/cmd/internal/invoke/command.go @@ -18,37 +18,15 @@ import ( "context" "encoding/json" "fmt" - "io" - "github.com/googleapis/genai-toolbox/internal/log" + "github.com/googleapis/genai-toolbox/cmd/internal" "github.com/googleapis/genai-toolbox/internal/server" "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/spf13/cobra" ) -// RootCommand defines the interface for required by invoke subcommand. -// This allows subcommands to access shared resources and functionality without -// direct coupling to the root command's implementation. -type RootCommand interface { - // Config returns a copy of the current server configuration. - Config() server.ServerConfig - - // Out returns the writer used for standard output. - Out() io.Writer - - // LoadConfig loads and merges the configuration from files, folders, and prebuilts. - LoadConfig(ctx context.Context) error - - // Setup initializes the runtime environment, including logging and telemetry. - // It returns the updated context and a shutdown function to be called when finished. - Setup(ctx context.Context) (context.Context, func(context.Context) error, error) - - // Logger returns the logger instance. - Logger() log.Logger -} - -func NewCommand(rootCmd RootCommand) *cobra.Command { +func NewCommand(opts *internal.ToolboxOptions) *cobra.Command { cmd := &cobra.Command{ Use: "invoke [params]", Short: "Execute a tool directly", @@ -58,17 +36,17 @@ Example: toolbox invoke my-tool '{"param1": "value1"}'`, Args: cobra.MinimumNArgs(1), RunE: func(c *cobra.Command, args []string) error { - return runInvoke(c, args, rootCmd) + return runInvoke(c, args, opts) }, } return cmd } -func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { +func runInvoke(cmd *cobra.Command, args []string, opts *internal.ToolboxOptions) error { ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() - ctx, shutdown, err := rootCmd.Setup(ctx) + ctx, shutdown, err := opts.Setup(ctx) if err != nil { return err } @@ -76,16 +54,16 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { _ = shutdown(ctx) }() - // Load and merge tool configurations - if err := rootCmd.LoadConfig(ctx); err != nil { + _, err = opts.LoadConfig(ctx) + if err != nil { return err } // Initialize Resources - sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, rootCmd.Config()) + sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, opts.Cfg) if err != nil { errMsg := fmt.Errorf("failed to initialize resources: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -96,7 +74,7 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { tool, ok := resourceMgr.GetTool(toolName) if !ok { errMsg := fmt.Errorf("tool %q not found", toolName) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -109,7 +87,7 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { if paramsInput != "" { if err := json.Unmarshal([]byte(paramsInput), ¶ms); err != nil { errMsg := fmt.Errorf("params must be a valid JSON string: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } } @@ -117,14 +95,14 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { parsedParams, err := parameters.ParseParams(tool.GetParameters(), params, nil) if err != nil { errMsg := fmt.Errorf("invalid parameters: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } parsedParams, err = tool.EmbedParams(ctx, parsedParams, resourceMgr.GetEmbeddingModelMap()) if err != nil { errMsg := fmt.Errorf("error embedding parameters: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -132,19 +110,19 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { requiresAuth, err := tool.RequiresClientAuthorization(resourceMgr) if err != nil { errMsg := fmt.Errorf("failed to check auth requirements: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } if requiresAuth { errMsg := fmt.Errorf("client authorization is not supported") - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } result, err := tool.Invoke(ctx, resourceMgr, parsedParams, "") if err != nil { errMsg := fmt.Errorf("tool execution failed: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -152,10 +130,10 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { output, err := json.MarshalIndent(result, "", " ") if err != nil { errMsg := fmt.Errorf("failed to marshal result: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } - fmt.Fprintln(rootCmd.Out(), string(output)) + fmt.Fprintln(opts.IOStreams.Out, string(output)) return nil } diff --git a/cmd/invoke_tool_test.go b/cmd/internal/invoke/command_test.go similarity index 80% rename from cmd/invoke_tool_test.go rename to cmd/internal/invoke/command_test.go index 4fa47817ef..3eab850acf 100644 --- a/cmd/invoke_tool_test.go +++ b/cmd/internal/invoke/command_test.go @@ -12,16 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -package cmd +package invoke import ( - "context" + "bytes" "os" "path/filepath" "strings" "testing" + + "github.com/googleapis/genai-toolbox/cmd/internal" + _ "github.com/googleapis/genai-toolbox/internal/sources/bigquery" + _ "github.com/googleapis/genai-toolbox/internal/sources/sqlite" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql" + _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql" + "github.com/spf13/cobra" ) +func invokeCommand(args []string) (string, error) { + parentCmd := &cobra.Command{Use: "toolbox"} + + buf := new(bytes.Buffer) + opts := internal.NewToolboxOptions(internal.WithIOStreams(buf, buf)) + internal.PersistentFlags(parentCmd, opts) + + cmd := NewCommand(opts) + parentCmd.AddCommand(cmd) + parentCmd.SetArgs(args) + + err := parentCmd.Execute() + return buf.String(), err +} + func TestInvokeTool(t *testing.T) { // Create a temporary tools file tmpDir := t.TempDir() @@ -86,7 +108,7 @@ tools: for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - _, got, err := invokeCommandWithContext(context.Background(), tc.args) + got, err := invokeCommand(tc.args) if (err != nil) != tc.wantErr { t.Fatalf("got error %v, wantErr %v", err, tc.wantErr) } @@ -121,7 +143,7 @@ tools: } args := []string{"invoke", "bq-tool", "--tools-file", toolsFilePath} - _, _, err := invokeCommandWithContext(context.Background(), args) + _, err := invokeCommand(args) if err == nil { t.Fatal("expected error for tool requiring client auth, but got nil") } diff --git a/cmd/internal/options.go b/cmd/internal/options.go new file mode 100644 index 0000000000..ea07771493 --- /dev/null +++ b/cmd/internal/options.go @@ -0,0 +1,251 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "fmt" + "io" + "os" + "slices" + "strings" + + "github.com/googleapis/genai-toolbox/internal/log" + "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/telemetry" + "github.com/googleapis/genai-toolbox/internal/util" +) + +type IOStreams struct { + In io.Reader + Out io.Writer + ErrOut io.Writer +} + +// ToolboxOptions holds dependencies shared by all commands. +type ToolboxOptions struct { + IOStreams IOStreams + Logger log.Logger + Cfg server.ServerConfig + ToolsFile string + ToolsFiles []string + ToolsFolder string + PrebuiltConfigs []string +} + +// Option defines a function that modifies the ToolboxOptions struct. +type Option func(*ToolboxOptions) + +// NewToolboxOptions creates a new instance with defaults, then applies any +// provided options. +func NewToolboxOptions(opts ...Option) *ToolboxOptions { + o := &ToolboxOptions{ + IOStreams: IOStreams{ + In: os.Stdin, + Out: os.Stdout, + ErrOut: os.Stderr, + }, + } + + for _, opt := range opts { + opt(o) + } + return o +} + +// Apply allows you to update an EXISTING ToolboxOptions instance. +// This is useful for "late binding". +func (o *ToolboxOptions) Apply(opts ...Option) { + for _, opt := range opts { + opt(o) + } +} + +// WithIOStreams updates the IO streams. +func WithIOStreams(out, err io.Writer) Option { + return func(o *ToolboxOptions) { + o.IOStreams.Out = out + o.IOStreams.ErrOut = err + } +} + +// Setup create logger and telemetry instrumentations. +func (opts *ToolboxOptions) Setup(ctx context.Context) (context.Context, func(context.Context) error, error) { + // If stdio, set logger's out stream (usually DEBUG and INFO logs) to + // errStream + loggerOut := opts.IOStreams.Out + if opts.Cfg.Stdio { + loggerOut = opts.IOStreams.ErrOut + } + + // Handle logger separately from config + logger, err := log.NewLogger(opts.Cfg.LoggingFormat.String(), opts.Cfg.LogLevel.String(), loggerOut, opts.IOStreams.ErrOut) + if err != nil { + return ctx, nil, fmt.Errorf("unable to initialize logger: %w", err) + } + + ctx = util.WithLogger(ctx, logger) + opts.Logger = logger + + // Set up OpenTelemetry + otelShutdown, err := telemetry.SetupOTel(ctx, opts.Cfg.Version, opts.Cfg.TelemetryOTLP, opts.Cfg.TelemetryGCP, opts.Cfg.TelemetryServiceName) + if err != nil { + errMsg := fmt.Errorf("error setting up OpenTelemetry: %w", err) + logger.ErrorContext(ctx, errMsg.Error()) + return ctx, nil, errMsg + } + + shutdownFunc := func(ctx context.Context) error { + err := otelShutdown(ctx) + if err != nil { + errMsg := fmt.Errorf("error shutting down OpenTelemetry: %w", err) + logger.ErrorContext(ctx, errMsg.Error()) + return err + } + return nil + } + + instrumentation, err := telemetry.CreateTelemetryInstrumentation(opts.Cfg.Version) + if err != nil { + errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err) + logger.ErrorContext(ctx, errMsg.Error()) + return ctx, shutdownFunc, errMsg + } + + ctx = util.WithInstrumentation(ctx, instrumentation) + + return ctx, shutdownFunc, nil +} + +// LoadConfig checks and merge files that should be loaded into the server +func (opts *ToolboxOptions) LoadConfig(ctx context.Context) (bool, error) { + // Determine if Custom Files should be loaded + // Check for explicit custom flags + isCustomConfigured := opts.ToolsFile != "" || len(opts.ToolsFiles) > 0 || opts.ToolsFolder != "" + + // Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags) + useDefaultToolsFile := len(opts.PrebuiltConfigs) == 0 && !isCustomConfigured + + if useDefaultToolsFile { + opts.ToolsFile = "tools.yaml" + isCustomConfigured = true + } + + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return isCustomConfigured, err + } + + var allToolsFiles []ToolsFile + + // Load Prebuilt Configuration + + if len(opts.PrebuiltConfigs) > 0 { + slices.Sort(opts.PrebuiltConfigs) + sourcesList := strings.Join(opts.PrebuiltConfigs, ", ") + logMsg := fmt.Sprintf("Using prebuilt tool configurations for: %s", sourcesList) + logger.InfoContext(ctx, logMsg) + + for _, configName := range opts.PrebuiltConfigs { + buf, err := prebuiltconfigs.Get(configName) + if err != nil { + logger.ErrorContext(ctx, err.Error()) + return isCustomConfigured, err + } + + // Parse into ToolsFile struct + parsed, err := parseToolsFile(ctx, buf) + if err != nil { + errMsg := fmt.Errorf("unable to parse prebuilt tool configuration for '%s': %w", configName, err) + logger.ErrorContext(ctx, errMsg.Error()) + return isCustomConfigured, errMsg + } + allToolsFiles = append(allToolsFiles, parsed) + } + } + + // Load Custom Configurations + if isCustomConfigured { + // Enforce exclusivity among custom flags (tools-file vs tools-files vs tools-folder) + if (opts.ToolsFile != "" && len(opts.ToolsFiles) > 0) || + (opts.ToolsFile != "" && opts.ToolsFolder != "") || + (len(opts.ToolsFiles) > 0 && opts.ToolsFolder != "") { + errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously") + logger.ErrorContext(ctx, errMsg.Error()) + return isCustomConfigured, errMsg + } + + var customTools ToolsFile + var err error + + if len(opts.ToolsFiles) > 0 { + // Use tools-files + logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(opts.ToolsFiles))) + customTools, err = LoadAndMergeToolsFiles(ctx, opts.ToolsFiles) + } else if opts.ToolsFolder != "" { + // Use tools-folder + logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", opts.ToolsFolder)) + customTools, err = LoadAndMergeToolsFolder(ctx, opts.ToolsFolder) + } else { + // Use single file (tools-file or default `tools.yaml`) + buf, readFileErr := os.ReadFile(opts.ToolsFile) + if readFileErr != nil { + errMsg := fmt.Errorf("unable to read tool file at %q: %w", opts.ToolsFile, readFileErr) + logger.ErrorContext(ctx, errMsg.Error()) + return isCustomConfigured, errMsg + } + customTools, err = parseToolsFile(ctx, buf) + if err != nil { + err = fmt.Errorf("unable to parse tool file at %q: %w", opts.ToolsFile, err) + } + } + + if err != nil { + logger.ErrorContext(ctx, err.Error()) + return isCustomConfigured, err + } + allToolsFiles = append(allToolsFiles, customTools) + } + + // Modify version string based on loaded configurations + if len(opts.PrebuiltConfigs) > 0 { + tag := "prebuilt" + if isCustomConfigured { + tag = "custom" + } + // prebuiltConfigs is already sorted above + for _, configName := range opts.PrebuiltConfigs { + opts.Cfg.Version += fmt.Sprintf("+%s.%s", tag, configName) + } + } + + // Merge Everything + // This will error if custom tools collide with prebuilt tools + finalToolsFile, err := mergeToolsFiles(allToolsFiles...) + if err != nil { + logger.ErrorContext(ctx, err.Error()) + return isCustomConfigured, err + } + + opts.Cfg.SourceConfigs = finalToolsFile.Sources + opts.Cfg.AuthServiceConfigs = finalToolsFile.AuthServices + opts.Cfg.EmbeddingModelConfigs = finalToolsFile.EmbeddingModels + opts.Cfg.ToolConfigs = finalToolsFile.Tools + opts.Cfg.ToolsetConfigs = finalToolsFile.Toolsets + opts.Cfg.PromptConfigs = finalToolsFile.Prompts + + return isCustomConfigured, nil +} diff --git a/cmd/options_test.go b/cmd/internal/options_test.go similarity index 62% rename from cmd/options_test.go rename to cmd/internal/options_test.go index e0ab779b52..6e7c0a05ed 100644 --- a/cmd/options_test.go +++ b/cmd/internal/options_test.go @@ -12,57 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -package cmd +package internal import ( "errors" "io" "testing" - - "github.com/spf13/cobra" ) -func TestCommandOptions(t *testing.T) { +func TestToolboxOptions(t *testing.T) { w := io.Discard tcs := []struct { desc string - isValid func(*Command) error + isValid func(*ToolboxOptions) error option Option }{ { desc: "with logger", - isValid: func(c *Command) error { - if c.outStream != w || c.errStream != w { + isValid: func(o *ToolboxOptions) error { + if o.IOStreams.Out != w || o.IOStreams.ErrOut != w { return errors.New("loggers do not match") } return nil }, - option: WithStreams(w, w), + option: WithIOStreams(w, w), }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - got, err := invokeProxyWithOption(tc.option) - if err != nil { - t.Fatal(err) - } + got := NewToolboxOptions(tc.option) if err := tc.isValid(got); err != nil { t.Errorf("option did not initialize command correctly: %v", err) } }) } } - -func invokeProxyWithOption(o Option) (*Command, error) { - c := NewCommand(o) - // Keep the test output quiet - c.SilenceUsage = true - c.SilenceErrors = true - // Disable execute behavior - c.RunE = func(*cobra.Command, []string) error { - return nil - } - - err := c.Execute() - return c, err -} diff --git a/cmd/internal/persistent_flags.go b/cmd/internal/persistent_flags.go new file mode 100644 index 0000000000..3874521a15 --- /dev/null +++ b/cmd/internal/persistent_flags.go @@ -0,0 +1,46 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "fmt" + "strings" + + "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" + "github.com/spf13/cobra" +) + +// PersistentFlags sets up flags that are available for all commands and +// subcommands +// It is also used to set up persistent flags during subcommand unit tests +func PersistentFlags(parentCmd *cobra.Command, opts *ToolboxOptions) { + persistentFlags := parentCmd.PersistentFlags() + + persistentFlags.StringVar(&opts.ToolsFile, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.") + persistentFlags.StringSliceVar(&opts.ToolsFiles, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file, or --tools-folder.") + persistentFlags.StringVar(&opts.ToolsFolder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file, or --tools-files.") + persistentFlags.Var(&opts.Cfg.LogLevel, "log-level", "Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'.") + persistentFlags.Var(&opts.Cfg.LoggingFormat, "logging-format", "Specify logging format to use. Allowed: 'standard' or 'JSON'.") + persistentFlags.BoolVar(&opts.Cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.") + persistentFlags.StringVar(&opts.Cfg.TelemetryOTLP, "telemetry-otlp", "", "Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318')") + persistentFlags.StringVar(&opts.Cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.") + // Fetch prebuilt tools sources to customize the help description + prebuiltHelp := fmt.Sprintf( + "Use a prebuilt tool configuration by source type. Allowed: '%s'. Can be specified multiple times.", + strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"), + ) + persistentFlags.StringSliceVar(&opts.PrebuiltConfigs, "prebuilt", []string{}, prebuiltHelp) + persistentFlags.StringSliceVar(&opts.Cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.") +} diff --git a/internal/cli/skills/command.go b/cmd/internal/skills/command.go similarity index 67% rename from internal/cli/skills/command.go rename to cmd/internal/skills/command.go index d8b2d286a9..d06c42c2b7 100644 --- a/internal/cli/skills/command.go +++ b/cmd/internal/skills/command.go @@ -22,7 +22,7 @@ import ( "path/filepath" "sort" - "github.com/googleapis/genai-toolbox/internal/log" + "github.com/googleapis/genai-toolbox/cmd/internal" "github.com/googleapis/genai-toolbox/internal/server" "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/tools" @@ -30,28 +30,9 @@ import ( "github.com/spf13/cobra" ) -// RootCommand defines the interface for required by skills-generate subcommand. -// This allows subcommands to access shared resources and functionality without -// direct coupling to the root command's implementation. -type RootCommand interface { - // Config returns a copy of the current server configuration. - Config() server.ServerConfig - - // LoadConfig loads and merges the configuration from files, folders, and prebuilts. - LoadConfig(ctx context.Context) error - - // Setup initializes the runtime environment, including logging and telemetry. - // It returns the updated context and a shutdown function to be called when finished. - Setup(ctx context.Context) (context.Context, func(context.Context) error, error) - - // Logger returns the logger instance. - Logger() log.Logger -} - -// Command is the command for generating skills. -type Command struct { +// skillsCmd is the command for generating skills. +type skillsCmd struct { *cobra.Command - rootCmd RootCommand name string description string toolset string @@ -59,15 +40,13 @@ type Command struct { } // NewCommand creates a new Command. -func NewCommand(rootCmd RootCommand) *cobra.Command { - cmd := &Command{ - rootCmd: rootCmd, - } +func NewCommand(opts *internal.ToolboxOptions) *cobra.Command { + cmd := &skillsCmd{} cmd.Command = &cobra.Command{ Use: "skills-generate", Short: "Generate skills from tool configurations", RunE: func(c *cobra.Command, args []string) error { - return cmd.run(c) + return run(cmd, opts) }, } @@ -81,11 +60,11 @@ func NewCommand(rootCmd RootCommand) *cobra.Command { return cmd.Command } -func (c *Command) run(cmd *cobra.Command) error { +func run(cmd *skillsCmd, opts *internal.ToolboxOptions) error { ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() - ctx, shutdown, err := c.rootCmd.Setup(ctx) + ctx, shutdown, err := opts.Setup(ctx) if err != nil { return err } @@ -93,39 +72,37 @@ func (c *Command) run(cmd *cobra.Command) error { _ = shutdown(ctx) }() - logger := c.rootCmd.Logger() - - // Load and merge tool configurations - if err := c.rootCmd.LoadConfig(ctx); err != nil { + _, err = opts.LoadConfig(ctx) + if err != nil { return err } - if err := os.MkdirAll(c.outputDir, 0755); err != nil { + if err := os.MkdirAll(cmd.outputDir, 0755); err != nil { errMsg := fmt.Errorf("error creating output directory: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } - logger.InfoContext(ctx, fmt.Sprintf("Generating skill '%s'...", c.name)) + opts.Logger.InfoContext(ctx, fmt.Sprintf("Generating skill '%s'...", cmd.name)) // Initialize toolbox and collect tools - allTools, err := c.collectTools(ctx) + allTools, err := cmd.collectTools(ctx, opts) if err != nil { errMsg := fmt.Errorf("error collecting tools: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } if len(allTools) == 0 { - logger.InfoContext(ctx, "No tools found to generate.") + opts.Logger.InfoContext(ctx, "No tools found to generate.") return nil } // Generate the combined skill directory - skillPath := filepath.Join(c.outputDir, c.name) + skillPath := filepath.Join(cmd.outputDir, cmd.name) if err := os.MkdirAll(skillPath, 0755); err != nil { errMsg := fmt.Errorf("error creating skill directory: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -133,7 +110,7 @@ func (c *Command) run(cmd *cobra.Command) error { assetsPath := filepath.Join(skillPath, "assets") if err := os.MkdirAll(assetsPath, 0755); err != nil { errMsg := fmt.Errorf("error creating assets dir: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -141,7 +118,7 @@ func (c *Command) run(cmd *cobra.Command) error { scriptsPath := filepath.Join(skillPath, "scripts") if err := os.MkdirAll(scriptsPath, 0755); err != nil { errMsg := fmt.Errorf("error creating scripts dir: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -154,10 +131,10 @@ func (c *Command) run(cmd *cobra.Command) error { for _, toolName := range toolNames { // Generate YAML config in asset directory - minimizedContent, err := generateToolConfigYAML(c.rootCmd.Config(), toolName) + minimizedContent, err := generateToolConfigYAML(opts.Cfg, toolName) if err != nil { errMsg := fmt.Errorf("error generating filtered config for %s: %w", toolName, err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -166,7 +143,7 @@ func (c *Command) run(cmd *cobra.Command) error { destPath := filepath.Join(assetsPath, specificToolsFileName) if err := os.WriteFile(destPath, minimizedContent, 0644); err != nil { errMsg := fmt.Errorf("error writing filtered config for %s: %w", toolName, err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } } @@ -175,40 +152,40 @@ func (c *Command) run(cmd *cobra.Command) error { scriptContent, err := generateScriptContent(toolName, specificToolsFileName) if err != nil { errMsg := fmt.Errorf("error generating script content for %s: %w", toolName, err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } scriptFilename := filepath.Join(scriptsPath, fmt.Sprintf("%s.js", toolName)) if err := os.WriteFile(scriptFilename, []byte(scriptContent), 0755); err != nil { errMsg := fmt.Errorf("error writing script %s: %w", scriptFilename, err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } } // Generate SKILL.md - skillContent, err := generateSkillMarkdown(c.name, c.description, allTools) + skillContent, err := generateSkillMarkdown(cmd.name, cmd.description, allTools) if err != nil { errMsg := fmt.Errorf("error generating SKILL.md content: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } skillMdPath := filepath.Join(skillPath, "SKILL.md") if err := os.WriteFile(skillMdPath, []byte(skillContent), 0644); err != nil { errMsg := fmt.Errorf("error writing SKILL.md: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } - logger.InfoContext(ctx, fmt.Sprintf("Successfully generated skill '%s' with %d tools.", c.name, len(allTools))) + opts.Logger.InfoContext(ctx, fmt.Sprintf("Successfully generated skill '%s' with %d tools.", cmd.name, len(allTools))) return nil } -func (c *Command) collectTools(ctx context.Context) (map[string]tools.Tool, error) { +func (c *skillsCmd) collectTools(ctx context.Context, opts *internal.ToolboxOptions) (map[string]tools.Tool, error) { // Initialize Resources - sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, c.rootCmd.Config()) + sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, opts.Cfg) if err != nil { return nil, fmt.Errorf("failed to initialize resources: %w", err) } diff --git a/cmd/skill_generate_test.go b/cmd/internal/skills/command_test.go similarity index 87% rename from cmd/skill_generate_test.go rename to cmd/internal/skills/command_test.go index 3b91dc590b..e7ddeafd1c 100644 --- a/cmd/skill_generate_test.go +++ b/cmd/internal/skills/command_test.go @@ -12,17 +12,36 @@ // See the License for the specific language governing permissions and // limitations under the License. -package cmd +package skills import ( - "context" + "bytes" "os" "path/filepath" "strings" "testing" - "time" + + "github.com/googleapis/genai-toolbox/cmd/internal" + _ "github.com/googleapis/genai-toolbox/internal/sources/sqlite" + _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql" + "github.com/spf13/cobra" ) +func invokeCommand(args []string) (string, error) { + parentCmd := &cobra.Command{Use: "toolbox"} + + buf := new(bytes.Buffer) + opts := internal.NewToolboxOptions(internal.WithIOStreams(buf, buf)) + internal.PersistentFlags(parentCmd, opts) + + cmd := NewCommand(opts) + parentCmd.AddCommand(cmd) + parentCmd.SetArgs(args) + + err := parentCmd.Execute() + return buf.String(), err +} + func TestGenerateSkill(t *testing.T) { // Create a temporary directory for tests tmpDir := t.TempDir() @@ -55,10 +74,7 @@ tools: "--description", "hello tool", } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - _, got, err := invokeCommandWithContext(ctx, args) + got, err := invokeCommand(args) if err != nil { t.Fatalf("command failed: %v\nOutput: %s", err, got) } @@ -136,7 +152,7 @@ func TestGenerateSkill_NoConfig(t *testing.T) { "--description", "test", } - _, _, err := invokeCommandWithContext(context.Background(), args) + _, err := invokeCommand(args) if err == nil { t.Fatal("expected command to fail when no configuration is provided and tools.yaml is missing") } @@ -170,7 +186,7 @@ func TestGenerateSkill_MissingArguments(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, got, err := invokeCommandWithContext(context.Background(), tt.args) + got, err := invokeCommand(tt.args) if err == nil { t.Fatalf("expected command to fail due to missing arguments, but it succeeded\nOutput: %s", got) } diff --git a/internal/cli/skills/generator.go b/cmd/internal/skills/generator.go similarity index 100% rename from internal/cli/skills/generator.go rename to cmd/internal/skills/generator.go diff --git a/internal/cli/skills/generator_test.go b/cmd/internal/skills/generator_test.go similarity index 100% rename from internal/cli/skills/generator_test.go rename to cmd/internal/skills/generator_test.go diff --git a/cmd/internal/tools_file.go b/cmd/internal/tools_file.go new file mode 100644 index 0000000000..ba91790845 --- /dev/null +++ b/cmd/internal/tools_file.go @@ -0,0 +1,349 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "slices" + "strings" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/server" +) + +type ToolsFile struct { + Sources server.SourceConfigs `yaml:"sources"` + AuthServices server.AuthServiceConfigs `yaml:"authServices"` + EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"` + Tools server.ToolConfigs `yaml:"tools"` + Toolsets server.ToolsetConfigs `yaml:"toolsets"` + Prompts server.PromptConfigs `yaml:"prompts"` +} + +// parseEnv replaces environment variables ${ENV_NAME} with their values. +// also support ${ENV_NAME:default_value}. +func parseEnv(input string) (string, error) { + re := regexp.MustCompile(`\$\{(\w+)(:([^}]*))?\}`) + + var err error + output := re.ReplaceAllStringFunc(input, func(match string) string { + parts := re.FindStringSubmatch(match) + + // extract the variable name + variableName := parts[1] + if value, found := os.LookupEnv(variableName); found { + return value + } + if len(parts) >= 4 && parts[2] != "" { + return parts[3] + } + err = fmt.Errorf("environment variable not found: %q", variableName) + return "" + }) + return output, err +} + +// parseToolsFile parses the provided yaml into appropriate configs. +func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) { + var toolsFile ToolsFile + // Replace environment variables if found + output, err := parseEnv(string(raw)) + if err != nil { + return toolsFile, fmt.Errorf("error parsing environment variables: %s", err) + } + raw = []byte(output) + + raw, err = convertToolsFile(raw) + if err != nil { + return toolsFile, fmt.Errorf("error converting tools file: %s", err) + } + + // Parse contents + toolsFile.Sources, toolsFile.AuthServices, toolsFile.EmbeddingModels, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts, err = server.UnmarshalResourceConfig(ctx, raw) + if err != nil { + return toolsFile, err + } + return toolsFile, nil +} + +func convertToolsFile(raw []byte) ([]byte, error) { + var input yaml.MapSlice + decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap()) + + // convert to tools file v2 + var buf bytes.Buffer + encoder := yaml.NewEncoder(&buf) + + v1keys := []string{"sources", "authSources", "authServices", "embeddingModels", "tools", "toolsets", "prompts"} + for { + if err := decoder.Decode(&input); err != nil { + if err == io.EOF { + break + } + return nil, err + } + for _, item := range input { + key, ok := item.Key.(string) + if !ok { + return nil, fmt.Errorf("unexpected non-string key in input: %v", item.Key) + } + // check if the key is config file v1's key + if slices.Contains(v1keys, key) { + // check if value conversion to yaml.MapSlice successfully + // fields such as "tools" in toolsets might pass the first check but + // fail to convert to MapSlice + if slice, ok := item.Value.(yaml.MapSlice); ok { + // Deprecated: convert authSources to authServices + if key == "authSources" { + key = "authServices" + } + transformed, err := transformDocs(key, slice) + if err != nil { + return nil, err + } + // encode per-doc + for _, doc := range transformed { + if err := encoder.Encode(doc); err != nil { + return nil, err + } + } + } else { + // invalid input will be ignored + // we don't want to throw error here since the config could + // be valid but with a different order such as: + // --- + // tools: + // - tool_a + // kind: toolsets + // --- + continue + } + } else { + // this doc is already v2, encode to buf + if err := encoder.Encode(input); err != nil { + return nil, err + } + break + } + } + } + return buf.Bytes(), nil +} + +// transformDocs transforms the configuration file from v1 format to v2 +// yaml.MapSlice will preserve the order in a map +func transformDocs(kind string, input yaml.MapSlice) ([]yaml.MapSlice, error) { + var transformed []yaml.MapSlice + for _, entry := range input { + entryName, ok := entry.Key.(string) + if !ok { + return nil, fmt.Errorf("unexpected non-string key for entry in '%s': %v", kind, entry.Key) + } + entryBody := ProcessValue(entry.Value, kind == "toolsets") + + currentTransformed := yaml.MapSlice{ + {Key: "kind", Value: kind}, + {Key: "name", Value: entryName}, + } + + // Merge the transformed body into our result + if bodySlice, ok := entryBody.(yaml.MapSlice); ok { + currentTransformed = append(currentTransformed, bodySlice...) + } else { + return nil, fmt.Errorf("unable to convert entryBody to MapSlice") + } + transformed = append(transformed, currentTransformed) + } + return transformed, nil +} + +// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type' +func ProcessValue(v any, isToolset bool) any { + switch val := v.(type) { + case yaml.MapSlice: + // creating a new MapSlice is safer for recursive transformation + newVal := make(yaml.MapSlice, len(val)) + for i, item := range val { + // Perform renaming + if item.Key == "kind" { + item.Key = "type" + } + // Recursive call for nested values (e.g., nested objects or lists) + item.Value = ProcessValue(item.Value, false) + newVal[i] = item + } + return newVal + case []any: + // Process lists: If it's a toolset top-level list, wrap it. + if isToolset { + return yaml.MapSlice{{Key: "tools", Value: val}} + } + // Otherwise, recurse into list items (to catch nested objects) + newVal := make([]any, len(val)) + for i := range val { + newVal[i] = ProcessValue(val[i], false) + } + return newVal + default: + return val + } +} + +// mergeToolsFiles merges multiple ToolsFile structs into one. +// Detects and raises errors for resource conflicts in sources, authServices, tools, and toolsets. +// All resource names (sources, authServices, tools, toolsets) must be unique across all files. +func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) { + merged := ToolsFile{ + Sources: make(server.SourceConfigs), + AuthServices: make(server.AuthServiceConfigs), + EmbeddingModels: make(server.EmbeddingModelConfigs), + Tools: make(server.ToolConfigs), + Toolsets: make(server.ToolsetConfigs), + Prompts: make(server.PromptConfigs), + } + + var conflicts []string + + for fileIndex, file := range files { + // Check for conflicts and merge sources + for name, source := range file.Sources { + if _, exists := merged.Sources[name]; exists { + conflicts = append(conflicts, fmt.Sprintf("source '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.Sources[name] = source + } + } + + // Check for conflicts and merge authServices + for name, authService := range file.AuthServices { + if _, exists := merged.AuthServices[name]; exists { + conflicts = append(conflicts, fmt.Sprintf("authService '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.AuthServices[name] = authService + } + } + + // Check for conflicts and merge embeddingModels + for name, em := range file.EmbeddingModels { + if _, exists := merged.EmbeddingModels[name]; exists { + conflicts = append(conflicts, fmt.Sprintf("embedding model '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.EmbeddingModels[name] = em + } + } + + // Check for conflicts and merge tools + for name, tool := range file.Tools { + if _, exists := merged.Tools[name]; exists { + conflicts = append(conflicts, fmt.Sprintf("tool '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.Tools[name] = tool + } + } + + // Check for conflicts and merge toolsets + for name, toolset := range file.Toolsets { + if _, exists := merged.Toolsets[name]; exists { + conflicts = append(conflicts, fmt.Sprintf("toolset '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.Toolsets[name] = toolset + } + } + + // Check for conflicts and merge prompts + for name, prompt := range file.Prompts { + if _, exists := merged.Prompts[name]; exists { + conflicts = append(conflicts, fmt.Sprintf("prompt '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.Prompts[name] = prompt + } + } + } + + // If conflicts were detected, return an error + if len(conflicts) > 0 { + return ToolsFile{}, fmt.Errorf("resource conflicts detected:\n - %s\n\nPlease ensure each source, authService, tool, toolset and prompt has a unique name across all files", strings.Join(conflicts, "\n - ")) + } + + return merged, nil +} + +// LoadAndMergeToolsFiles loads multiple YAML files and merges them +func LoadAndMergeToolsFiles(ctx context.Context, filePaths []string) (ToolsFile, error) { + var toolsFiles []ToolsFile + + for _, filePath := range filePaths { + buf, err := os.ReadFile(filePath) + if err != nil { + return ToolsFile{}, fmt.Errorf("unable to read tool file at %q: %w", filePath, err) + } + + toolsFile, err := parseToolsFile(ctx, buf) + if err != nil { + return ToolsFile{}, fmt.Errorf("unable to parse tool file at %q: %w", filePath, err) + } + + toolsFiles = append(toolsFiles, toolsFile) + } + + mergedFile, err := mergeToolsFiles(toolsFiles...) + if err != nil { + return ToolsFile{}, fmt.Errorf("unable to merge tools files: %w", err) + } + + return mergedFile, nil +} + +// LoadAndMergeToolsFolder loads all YAML files from a directory and merges them +func LoadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile, error) { + // Check if directory exists + info, err := os.Stat(folderPath) + if err != nil { + return ToolsFile{}, fmt.Errorf("unable to access tools folder at %q: %w", folderPath, err) + } + if !info.IsDir() { + return ToolsFile{}, fmt.Errorf("path %q is not a directory", folderPath) + } + + // Find all YAML files in the directory + pattern := filepath.Join(folderPath, "*.yaml") + yamlFiles, err := filepath.Glob(pattern) + if err != nil { + return ToolsFile{}, fmt.Errorf("error finding YAML files in %q: %w", folderPath, err) + } + + // Also find .yml files + ymlPattern := filepath.Join(folderPath, "*.yml") + ymlFiles, err := filepath.Glob(ymlPattern) + if err != nil { + return ToolsFile{}, fmt.Errorf("error finding YML files in %q: %w", folderPath, err) + } + + // Combine both file lists + allFiles := append(yamlFiles, ymlFiles...) + + if len(allFiles) == 0 { + return ToolsFile{}, fmt.Errorf("no YAML files found in directory %q", folderPath) + } + + // Use existing LoadAndMergeToolsFiles function + return LoadAndMergeToolsFiles(ctx, allFiles) +} diff --git a/cmd/internal/tools_file_test.go b/cmd/internal/tools_file_test.go new file mode 100644 index 0000000000..3b26baa621 --- /dev/null +++ b/cmd/internal/tools_file_test.go @@ -0,0 +1,2141 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "fmt" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/auth/google" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini" + "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" + "github.com/googleapis/genai-toolbox/internal/prompts" + "github.com/googleapis/genai-toolbox/internal/prompts/custom" + "github.com/googleapis/genai-toolbox/internal/server" + cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" + httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/http" + "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +func TestParseEnv(t *testing.T) { + tcs := []struct { + desc string + env map[string]string + in string + want string + err bool + errString string + }{ + { + desc: "without default without env", + in: "${FOO}", + want: "", + err: true, + errString: `environment variable not found: "FOO"`, + }, + { + desc: "without default with env", + env: map[string]string{ + "FOO": "bar", + }, + in: "${FOO}", + want: "bar", + }, + { + desc: "with empty default", + in: "${FOO:}", + want: "", + }, + { + desc: "with default", + in: "${FOO:bar}", + want: "bar", + }, + { + desc: "with default with env", + env: map[string]string{ + "FOO": "hello", + }, + in: "${FOO:bar}", + want: "hello", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + if tc.env != nil { + for k, v := range tc.env { + t.Setenv(k, v) + } + } + got, err := parseEnv(tc.in) + if tc.err { + if err == nil { + t.Fatalf("expected error not found") + } + if tc.errString != err.Error() { + t.Fatalf("incorrect error string: got %s, want %s", err, tc.errString) + } + } + if tc.want != got { + t.Fatalf("unexpected want: got %s, want %s", got, tc.want) + } + }) + } +} + +func TestConvertToolsFile(t *testing.T) { + tcs := []struct { + desc string + in string + want string + isErr bool + errStr string + }{ + { + desc: "basic convert", + in: ` + sources: + my-pg-instance: + kind: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + authServices: + my-google-auth: + kind: google + clientId: testing-id + tools: + example_tool: + kind: postgres-sql + source: my-pg-instance + description: some description + statement: SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description + toolsets: + example_toolset: + - example_tool + prompts: + code_review: + description: ask llm to analyze code quality + messages: + - content: "please review the following code for quality: {{.code}}" + arguments: + - name: code + description: the code to review + embeddingModels: + gemini-model: + kind: gemini + model: gemini-embedding-001 + apiKey: some-key + dimension: 768`, + want: `kind: sources +name: my-pg-instance +type: cloud-sql-postgres +project: my-project +region: my-region +instance: my-instance +database: my_db +user: my_user +password: my_pass +--- +kind: authServices +name: my-google-auth +type: google +clientId: testing-id +--- +kind: tools +name: example_tool +type: postgres-sql +source: my-pg-instance +description: some description +statement: SELECT * FROM SQL_STATEMENT; +parameters: +- name: country + type: string + description: some description +--- +kind: toolsets +name: example_toolset +tools: +- example_tool +--- +kind: prompts +name: code_review +description: ask llm to analyze code quality +messages: +- content: "please review the following code for quality: {{.code}}" +arguments: +- name: code + description: the code to review +--- +kind: embeddingModels +name: gemini-model +type: gemini +model: gemini-embedding-001 +apiKey: some-key +dimension: 768 +`, + }, + { + desc: "preserve resource order", + in: ` + tools: + example_tool: + kind: postgres-sql + source: my-pg-instance + description: some description + statement: SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description + sources: + my-pg-instance: + kind: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + authServices: + my-google-auth: + kind: google + clientId: testing-id + toolsets: + example_toolset: + - example_tool + authSources: + my-google-auth2: + kind: google + clientId: testing-id`, + want: `kind: tools +name: example_tool +type: postgres-sql +source: my-pg-instance +description: some description +statement: SELECT * FROM SQL_STATEMENT; +parameters: +- name: country + type: string + description: some description +--- +kind: sources +name: my-pg-instance +type: cloud-sql-postgres +project: my-project +region: my-region +instance: my-instance +database: my_db +user: my_user +password: my_pass +--- +kind: authServices +name: my-google-auth +type: google +clientId: testing-id +--- +kind: toolsets +name: example_toolset +tools: +- example_tool +--- +kind: authServices +name: my-google-auth2 +type: google +clientId: testing-id +`, + }, + { + desc: "convert combination of v1 and v2", + in: ` + sources: + my-pg-instance: + kind: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + authServices: + my-google-auth: + kind: google + clientId: testing-id + tools: + example_tool: + kind: postgres-sql + source: my-pg-instance + description: some description + statement: SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description + toolsets: + example_toolset: + - example_tool + prompts: + code_review: + description: ask llm to analyze code quality + messages: + - content: "please review the following code for quality: {{.code}}" + arguments: + - name: code + description: the code to review + embeddingModels: + gemini-model: + kind: gemini + model: gemini-embedding-001 + apiKey: some-key + dimension: 768 +--- + kind: sources + name: my-pg-instance2 + type: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance +--- + kind: authServices + name: my-google-auth2 + type: google + clientId: testing-id +--- + kind: tools + name: example_tool2 + type: postgres-sql + source: my-pg-instance + description: some description + statement: SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description +--- + kind: toolsets + name: example_toolset2 + tools: + - example_tool +--- + tools: + - example_tool + kind: toolsets + name: example_toolset3 +--- + kind: prompts + name: code_review2 + description: ask llm to analyze code quality + messages: + - content: "please review the following code for quality: {{.code}}" + arguments: + - name: code + description: the code to review +--- + kind: embeddingModels + name: gemini-model2 + type: gemini`, + want: `kind: sources +name: my-pg-instance +type: cloud-sql-postgres +project: my-project +region: my-region +instance: my-instance +database: my_db +user: my_user +password: my_pass +--- +kind: authServices +name: my-google-auth +type: google +clientId: testing-id +--- +kind: tools +name: example_tool +type: postgres-sql +source: my-pg-instance +description: some description +statement: SELECT * FROM SQL_STATEMENT; +parameters: +- name: country + type: string + description: some description +--- +kind: toolsets +name: example_toolset +tools: +- example_tool +--- +kind: prompts +name: code_review +description: ask llm to analyze code quality +messages: +- content: "please review the following code for quality: {{.code}}" +arguments: +- name: code + description: the code to review +--- +kind: embeddingModels +name: gemini-model +type: gemini +model: gemini-embedding-001 +apiKey: some-key +dimension: 768 +--- +kind: sources +name: my-pg-instance2 +type: cloud-sql-postgres +project: my-project +region: my-region +instance: my-instance +--- +kind: authServices +name: my-google-auth2 +type: google +clientId: testing-id +--- +kind: tools +name: example_tool2 +type: postgres-sql +source: my-pg-instance +description: some description +statement: SELECT * FROM SQL_STATEMENT; +parameters: +- name: country + type: string + description: some description +--- +kind: toolsets +name: example_toolset2 +tools: +- example_tool +--- +tools: +- example_tool +kind: toolsets +name: example_toolset3 +--- +kind: prompts +name: code_review2 +description: ask llm to analyze code quality +messages: +- content: "please review the following code for quality: {{.code}}" +arguments: +- name: code + description: the code to review +--- +kind: embeddingModels +name: gemini-model2 +type: gemini +`, + }, + { + desc: "no convertion needed", + in: `kind: sources +name: my-pg-instance +type: cloud-sql-postgres +project: my-project +region: my-region +instance: my-instance +database: my_db +user: my_user +password: my_pass +--- +kind: tools +name: example_tool +type: postgres-sql +source: my-pg-instance +description: some description +statement: SELECT * FROM SQL_STATEMENT; +parameters: +- name: country + type: string + description: some description +--- +kind: toolsets +name: example_toolset +tools: +- example_tool`, + want: `kind: sources +name: my-pg-instance +type: cloud-sql-postgres +project: my-project +region: my-region +instance: my-instance +database: my_db +user: my_user +password: my_pass +--- +kind: tools +name: example_tool +type: postgres-sql +source: my-pg-instance +description: some description +statement: SELECT * FROM SQL_STATEMENT; +parameters: +- name: country + type: string + description: some description +--- +kind: toolsets +name: example_toolset +tools: +- example_tool +`, + }, + { + desc: "invalid source", + in: `sources: invalid`, + want: "", + }, + { + desc: "invalid toolset", + in: `toolsets: invalid`, + want: "", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + output, err := convertToolsFile([]byte(tc.in)) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if diff := cmp.Diff(string(output), tc.want); diff != "" { + t.Fatalf("incorrect toolsets parse: diff %v", diff) + } + }) + } +} + +func TestParseToolFile(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + description string + in string + wantToolsFile ToolsFile + }{ + { + description: "basic example tools file v1", + in: ` + sources: + my-pg-instance: + kind: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + tools: + example_tool: + kind: postgres-sql + source: my-pg-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description + toolsets: + example_toolset: + - example_tool + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-pg-instance": cloudsqlpgsrc.Config{ + Name: "my-pg-instance", + Type: cloudsqlpgsrc.SourceType, + Project: "my-project", + Region: "my-region", + Instance: "my-instance", + IPType: "public", + Database: "my_db", + User: "my_user", + Password: "my_pass", + }, + }, + Tools: server.ToolConfigs{ + "example_tool": postgressql.Config{ + Name: "example_tool", + Type: "postgres-sql", + Source: "my-pg-instance", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("country", "some description"), + }, + AuthRequired: []string{}, + }, + }, + Toolsets: server.ToolsetConfigs{ + "example_toolset": tools.ToolsetConfig{ + Name: "example_toolset", + ToolNames: []string{"example_tool"}, + }, + }, + AuthServices: nil, + Prompts: nil, + }, + }, + { + description: "basic example tools file v2", + in: ` + kind: sources + name: my-pg-instance + type: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass +--- + kind: authServices + name: my-google-auth + type: google + clientId: testing-id +--- + kind: embeddingModels + name: gemini-model + type: gemini + model: gemini-embedding-001 + apiKey: some-key + dimension: 768 +--- + kind: tools + name: example_tool + type: postgres-sql + source: my-pg-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description +--- + kind: toolsets + name: example_toolset + tools: + - example_tool +--- + kind: prompts + name: code_review + description: ask llm to analyze code quality + messages: + - content: "please review the following code for quality: {{.code}}" + arguments: + - name: code + description: the code to review + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-pg-instance": cloudsqlpgsrc.Config{ + Name: "my-pg-instance", + Type: cloudsqlpgsrc.SourceType, + Project: "my-project", + Region: "my-region", + Instance: "my-instance", + IPType: "public", + Database: "my_db", + User: "my_user", + Password: "my_pass", + }, + }, + AuthServices: server.AuthServiceConfigs{ + "my-google-auth": google.Config{ + Name: "my-google-auth", + Type: google.AuthServiceType, + ClientID: "testing-id", + }, + }, + EmbeddingModels: server.EmbeddingModelConfigs{ + "gemini-model": gemini.Config{ + Name: "gemini-model", + Type: gemini.EmbeddingModelType, + Model: "gemini-embedding-001", + ApiKey: "some-key", + Dimension: 768, + }, + }, + Tools: server.ToolConfigs{ + "example_tool": postgressql.Config{ + Name: "example_tool", + Type: "postgres-sql", + Source: "my-pg-instance", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("country", "some description"), + }, + AuthRequired: []string{}, + }, + }, + Toolsets: server.ToolsetConfigs{ + "example_toolset": tools.ToolsetConfig{ + Name: "example_toolset", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: server.PromptConfigs{ + "code_review": &custom.Config{ + Name: "code_review", + Description: "ask llm to analyze code quality", + Arguments: prompts.Arguments{ + {Parameter: parameters.NewStringParameter("code", "the code to review")}, + }, + Messages: []prompts.Message{ + {Role: "user", Content: "please review the following code for quality: {{.code}}"}, + }, + }, + }, + }, + }, + { + description: "only prompts", + in: ` + kind: prompts + name: my-prompt + description: A prompt template for data analysis. + arguments: + - name: country + description: The country to analyze. + messages: + - content: Analyze the data for {{.country}}. + `, + wantToolsFile: ToolsFile{ + Sources: nil, + AuthServices: nil, + Tools: nil, + Toolsets: nil, + Prompts: server.PromptConfigs{ + "my-prompt": &custom.Config{ + Name: "my-prompt", + Description: "A prompt template for data analysis.", + Arguments: prompts.Arguments{ + {Parameter: parameters.NewStringParameter("country", "The country to analyze.")}, + }, + Messages: []prompts.Message{ + {Role: "user", Content: "Analyze the data for {{.country}}."}, + }, + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.description, func(t *testing.T) { + toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("failed to parse input: %v", err) + } + if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" { + t.Fatalf("incorrect sources parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" { + t.Fatalf("incorrect authServices parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" { + t.Fatalf("incorrect tools parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" { + t.Fatalf("incorrect toolsets parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" { + t.Fatalf("incorrect prompts parse: diff %v", diff) + } + }) + } +} + +func TestParseToolFileWithAuth(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + description string + in string + wantToolsFile ToolsFile + }{ + { + description: "basic example", + in: ` + kind: sources + name: my-pg-instance + type: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass +--- + kind: authServices + name: my-google-service + type: google + clientId: my-client-id +--- + kind: authServices + name: other-google-service + type: google + clientId: other-client-id +--- + kind: tools + name: example_tool + type: postgres-sql + source: my-pg-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description + - name: id + type: integer + description: user id + authServices: + - name: my-google-service + field: user_id + - name: email + type: string + description: user email + authServices: + - name: my-google-service + field: email + - name: other-google-service + field: other_email +--- + kind: toolsets + name: example_toolset + tools: + - example_tool + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-pg-instance": cloudsqlpgsrc.Config{ + Name: "my-pg-instance", + Type: cloudsqlpgsrc.SourceType, + Project: "my-project", + Region: "my-region", + Instance: "my-instance", + IPType: "public", + Database: "my_db", + User: "my_user", + Password: "my_pass", + }, + }, + AuthServices: server.AuthServiceConfigs{ + "my-google-service": google.Config{ + Name: "my-google-service", + Type: google.AuthServiceType, + ClientID: "my-client-id", + }, + "other-google-service": google.Config{ + Name: "other-google-service", + Type: google.AuthServiceType, + ClientID: "other-client-id", + }, + }, + Tools: server.ToolConfigs{ + "example_tool": postgressql.Config{ + Name: "example_tool", + Type: "postgres-sql", + Source: "my-pg-instance", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + AuthRequired: []string{}, + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("country", "some description"), + parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}), + parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}), + }, + }, + }, + Toolsets: server.ToolsetConfigs{ + "example_toolset": tools.ToolsetConfig{ + Name: "example_toolset", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: nil, + }, + }, + { + description: "basic example with authSources", + in: ` + sources: + my-pg-instance: + kind: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + authSources: + my-google-service: + kind: google + clientId: my-client-id + other-google-service: + kind: google + clientId: other-client-id + + tools: + example_tool: + kind: postgres-sql + source: my-pg-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description + - name: id + type: integer + description: user id + authSources: + - name: my-google-service + field: user_id + - name: email + type: string + description: user email + authSources: + - name: my-google-service + field: email + - name: other-google-service + field: other_email + + toolsets: + example_toolset: + - example_tool + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-pg-instance": cloudsqlpgsrc.Config{ + Name: "my-pg-instance", + Type: cloudsqlpgsrc.SourceType, + Project: "my-project", + Region: "my-region", + Instance: "my-instance", + IPType: "public", + Database: "my_db", + User: "my_user", + Password: "my_pass", + }, + }, + AuthServices: server.AuthServiceConfigs{ + "my-google-service": google.Config{ + Name: "my-google-service", + Type: google.AuthServiceType, + ClientID: "my-client-id", + }, + "other-google-service": google.Config{ + Name: "other-google-service", + Type: google.AuthServiceType, + ClientID: "other-client-id", + }, + }, + Tools: server.ToolConfigs{ + "example_tool": postgressql.Config{ + Name: "example_tool", + Type: "postgres-sql", + Source: "my-pg-instance", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + AuthRequired: []string{}, + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("country", "some description"), + parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}), + parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}), + }, + }, + }, + Toolsets: server.ToolsetConfigs{ + "example_toolset": tools.ToolsetConfig{ + Name: "example_toolset", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: nil, + }, + }, + { + description: "basic example with authRequired", + in: ` + kind: sources + name: my-pg-instance + type: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass +--- + kind: authServices + name: my-google-service + type: google + clientId: my-client-id +--- + kind: authServices + name: other-google-service + type: google + clientId: other-client-id +--- + kind: tools + name: example_tool + type: postgres-sql + source: my-pg-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + authRequired: + - my-google-service + parameters: + - name: country + type: string + description: some description + - name: id + type: integer + description: user id + authServices: + - name: my-google-service + field: user_id + - name: email + type: string + description: user email + authServices: + - name: my-google-service + field: email + - name: other-google-service + field: other_email +--- + kind: toolsets + name: example_toolset + tools: + - example_tool + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-pg-instance": cloudsqlpgsrc.Config{ + Name: "my-pg-instance", + Type: cloudsqlpgsrc.SourceType, + Project: "my-project", + Region: "my-region", + Instance: "my-instance", + IPType: "public", + Database: "my_db", + User: "my_user", + Password: "my_pass", + }, + }, + AuthServices: server.AuthServiceConfigs{ + "my-google-service": google.Config{ + Name: "my-google-service", + Type: google.AuthServiceType, + ClientID: "my-client-id", + }, + "other-google-service": google.Config{ + Name: "other-google-service", + Type: google.AuthServiceType, + ClientID: "other-client-id", + }, + }, + Tools: server.ToolConfigs{ + "example_tool": postgressql.Config{ + Name: "example_tool", + Type: "postgres-sql", + Source: "my-pg-instance", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + AuthRequired: []string{"my-google-service"}, + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("country", "some description"), + parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}), + parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}), + }, + }, + }, + Toolsets: server.ToolsetConfigs{ + "example_toolset": tools.ToolsetConfig{ + Name: "example_toolset", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: nil, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.description, func(t *testing.T) { + toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("failed to parse input: %v", err) + } + if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" { + t.Fatalf("incorrect sources parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" { + t.Fatalf("incorrect authServices parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" { + t.Fatalf("incorrect tools parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" { + t.Fatalf("incorrect toolsets parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" { + t.Fatalf("incorrect prompts parse: diff %v", diff) + } + }) + } + +} + +func TestEnvVarReplacement(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + t.Setenv("TestHeader", "ACTUAL_HEADER") + t.Setenv("API_KEY", "ACTUAL_API_KEY") + t.Setenv("clientId", "ACTUAL_CLIENT_ID") + t.Setenv("clientId2", "ACTUAL_CLIENT_ID_2") + t.Setenv("toolset_name", "ACTUAL_TOOLSET_NAME") + t.Setenv("cat_string", "cat") + t.Setenv("food_string", "food") + t.Setenv("TestHeader", "ACTUAL_HEADER") + t.Setenv("prompt_name", "ACTUAL_PROMPT_NAME") + t.Setenv("prompt_content", "ACTUAL_CONTENT") + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + description string + in string + wantToolsFile ToolsFile + }{ + { + description: "file with env var example", + in: ` + sources: + my-http-instance: + kind: http + baseUrl: http://test_server/ + timeout: 10s + headers: + Authorization: ${TestHeader} + queryParams: + api-key: ${API_KEY} + authServices: + my-google-service: + kind: google + clientId: ${clientId} + other-google-service: + kind: google + clientId: ${clientId2} + + tools: + example_tool: + kind: http + source: my-instance + method: GET + path: "search?name=alice&pet=${cat_string}" + description: some description + authRequired: + - my-google-auth-service + - other-auth-service + queryParams: + - name: country + type: string + description: some description + authServices: + - name: my-google-auth-service + field: user_id + - name: other-auth-service + field: user_id + requestBody: | + { + "age": {{.age}}, + "city": "{{.city}}", + "food": "${food_string}", + "other": "$OTHER" + } + bodyParams: + - name: age + type: integer + description: age num + - name: city + type: string + description: city string + headers: + Authorization: API_KEY + Content-Type: application/json + headerParams: + - name: Language + type: string + description: language string + + toolsets: + ${toolset_name}: + - example_tool + + + prompts: + ${prompt_name}: + description: A test prompt for {{.name}}. + messages: + - role: user + content: ${prompt_content} + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-http-instance": httpsrc.Config{ + Name: "my-http-instance", + Type: httpsrc.SourceType, + BaseURL: "http://test_server/", + Timeout: "10s", + DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"}, + QueryParams: map[string]string{"api-key": "ACTUAL_API_KEY"}, + }, + }, + AuthServices: server.AuthServiceConfigs{ + "my-google-service": google.Config{ + Name: "my-google-service", + Type: google.AuthServiceType, + ClientID: "ACTUAL_CLIENT_ID", + }, + "other-google-service": google.Config{ + Name: "other-google-service", + Type: google.AuthServiceType, + ClientID: "ACTUAL_CLIENT_ID_2", + }, + }, + Tools: server.ToolConfigs{ + "example_tool": http.Config{ + Name: "example_tool", + Type: "http", + Source: "my-instance", + Method: "GET", + Path: "search?name=alice&pet=cat", + Description: "some description", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + QueryParams: []parameters.Parameter{ + parameters.NewStringParameterWithAuth("country", "some description", + []parameters.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, + {Name: "other-auth-service", Field: "user_id"}}), + }, + RequestBody: `{ + "age": {{.age}}, + "city": "{{.city}}", + "food": "food", + "other": "$OTHER" +} +`, + BodyParams: []parameters.Parameter{parameters.NewIntParameter("age", "age num"), parameters.NewStringParameter("city", "city string")}, + Headers: map[string]string{"Authorization": "API_KEY", "Content-Type": "application/json"}, + HeaderParams: []parameters.Parameter{parameters.NewStringParameter("Language", "language string")}, + }, + }, + Toolsets: server.ToolsetConfigs{ + "ACTUAL_TOOLSET_NAME": tools.ToolsetConfig{ + Name: "ACTUAL_TOOLSET_NAME", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: server.PromptConfigs{ + "ACTUAL_PROMPT_NAME": &custom.Config{ + Name: "ACTUAL_PROMPT_NAME", + Description: "A test prompt for {{.name}}.", + Messages: []prompts.Message{ + { + Role: "user", + Content: "ACTUAL_CONTENT", + }, + }, + Arguments: nil, + }, + }, + }, + }, + { + description: "file with env var example toolsfile v2", + in: ` + kind: sources + name: my-http-instance + type: http + baseUrl: http://test_server/ + timeout: 10s + headers: + Authorization: ${TestHeader} + queryParams: + api-key: ${API_KEY} +--- + kind: authServices + name: my-google-service + type: google + clientId: ${clientId} +--- + kind: authServices + name: other-google-service + type: google + clientId: ${clientId2} +--- + kind: tools + name: example_tool + type: http + source: my-instance + method: GET + path: "search?name=alice&pet=${cat_string}" + description: some description + authRequired: + - my-google-auth-service + - other-auth-service + queryParams: + - name: country + type: string + description: some description + authServices: + - name: my-google-auth-service + field: user_id + - name: other-auth-service + field: user_id + requestBody: | + { + "age": {{.age}}, + "city": "{{.city}}", + "food": "${food_string}", + "other": "$OTHER" + } + bodyParams: + - name: age + type: integer + description: age num + - name: city + type: string + description: city string + headers: + Authorization: API_KEY + Content-Type: application/json + headerParams: + - name: Language + type: string + description: language string +--- + kind: toolsets + name: ${toolset_name} + tools: + - example_tool +--- + kind: prompts + name: ${prompt_name} + description: A test prompt for {{.name}}. + messages: + - role: user + content: ${prompt_content} + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-http-instance": httpsrc.Config{ + Name: "my-http-instance", + Type: httpsrc.SourceType, + BaseURL: "http://test_server/", + Timeout: "10s", + DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"}, + QueryParams: map[string]string{"api-key": "ACTUAL_API_KEY"}, + }, + }, + AuthServices: server.AuthServiceConfigs{ + "my-google-service": google.Config{ + Name: "my-google-service", + Type: google.AuthServiceType, + ClientID: "ACTUAL_CLIENT_ID", + }, + "other-google-service": google.Config{ + Name: "other-google-service", + Type: google.AuthServiceType, + ClientID: "ACTUAL_CLIENT_ID_2", + }, + }, + Tools: server.ToolConfigs{ + "example_tool": http.Config{ + Name: "example_tool", + Type: "http", + Source: "my-instance", + Method: "GET", + Path: "search?name=alice&pet=cat", + Description: "some description", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + QueryParams: []parameters.Parameter{ + parameters.NewStringParameterWithAuth("country", "some description", + []parameters.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, + {Name: "other-auth-service", Field: "user_id"}}), + }, + RequestBody: `{ + "age": {{.age}}, + "city": "{{.city}}", + "food": "food", + "other": "$OTHER" +} +`, + BodyParams: []parameters.Parameter{parameters.NewIntParameter("age", "age num"), parameters.NewStringParameter("city", "city string")}, + Headers: map[string]string{"Authorization": "API_KEY", "Content-Type": "application/json"}, + HeaderParams: []parameters.Parameter{parameters.NewStringParameter("Language", "language string")}, + }, + }, + Toolsets: server.ToolsetConfigs{ + "ACTUAL_TOOLSET_NAME": tools.ToolsetConfig{ + Name: "ACTUAL_TOOLSET_NAME", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: server.PromptConfigs{ + "ACTUAL_PROMPT_NAME": &custom.Config{ + Name: "ACTUAL_PROMPT_NAME", + Description: "A test prompt for {{.name}}.", + Messages: []prompts.Message{ + { + Role: "user", + Content: "ACTUAL_CONTENT", + }, + }, + Arguments: nil, + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.description, func(t *testing.T) { + toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("failed to parse input: %v", err) + } + if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" { + t.Fatalf("incorrect sources parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" { + t.Fatalf("incorrect authServices parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" { + t.Fatalf("incorrect tools parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" { + t.Fatalf("incorrect toolsets parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" { + t.Fatalf("incorrect prompts parse: diff %v", diff) + } + }) + } +} + +func TestPrebuiltTools(t *testing.T) { + // Get prebuilt configs + alloydb_omni_config, _ := prebuiltconfigs.Get("alloydb-omni") + alloydb_admin_config, _ := prebuiltconfigs.Get("alloydb-postgres-admin") + alloydb_config, _ := prebuiltconfigs.Get("alloydb-postgres") + bigquery_config, _ := prebuiltconfigs.Get("bigquery") + clickhouse_config, _ := prebuiltconfigs.Get("clickhouse") + cloudsqlpg_config, _ := prebuiltconfigs.Get("cloud-sql-postgres") + cloudsqlpg_admin_config, _ := prebuiltconfigs.Get("cloud-sql-postgres-admin") + cloudsqlmysql_config, _ := prebuiltconfigs.Get("cloud-sql-mysql") + cloudsqlmysql_admin_config, _ := prebuiltconfigs.Get("cloud-sql-mysql-admin") + cloudsqlmssql_config, _ := prebuiltconfigs.Get("cloud-sql-mssql") + cloudsqlmssql_admin_config, _ := prebuiltconfigs.Get("cloud-sql-mssql-admin") + dataplex_config, _ := prebuiltconfigs.Get("dataplex") + firestoreconfig, _ := prebuiltconfigs.Get("firestore") + mysql_config, _ := prebuiltconfigs.Get("mysql") + mssql_config, _ := prebuiltconfigs.Get("mssql") + looker_config, _ := prebuiltconfigs.Get("looker") + lookerca_config, _ := prebuiltconfigs.Get("looker-conversational-analytics") + postgresconfig, _ := prebuiltconfigs.Get("postgres") + spanner_config, _ := prebuiltconfigs.Get("spanner") + spannerpg_config, _ := prebuiltconfigs.Get("spanner-postgres") + mindsdb_config, _ := prebuiltconfigs.Get("mindsdb") + sqlite_config, _ := prebuiltconfigs.Get("sqlite") + neo4jconfig, _ := prebuiltconfigs.Get("neo4j") + alloydbobsvconfig, _ := prebuiltconfigs.Get("alloydb-postgres-observability") + cloudsqlpgobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-postgres-observability") + cloudsqlmysqlobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-mysql-observability") + cloudsqlmssqlobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-mssql-observability") + serverless_spark_config, _ := prebuiltconfigs.Get("serverless-spark") + cloudhealthcare_config, _ := prebuiltconfigs.Get("cloud-healthcare") + snowflake_config, _ := prebuiltconfigs.Get("snowflake") + + // Set environment variables + t.Setenv("API_KEY", "your_api_key") + + t.Setenv("BIGQUERY_PROJECT", "your_gcp_project_id") + t.Setenv("DATAPLEX_PROJECT", "your_gcp_project_id") + t.Setenv("FIRESTORE_PROJECT", "your_gcp_project_id") + t.Setenv("FIRESTORE_DATABASE", "your_firestore_db_name") + + t.Setenv("SPANNER_PROJECT", "your_gcp_project_id") + t.Setenv("SPANNER_INSTANCE", "your_spanner_instance") + t.Setenv("SPANNER_DATABASE", "your_spanner_db") + + t.Setenv("ALLOYDB_POSTGRES_PROJECT", "your_gcp_project_id") + t.Setenv("ALLOYDB_POSTGRES_REGION", "your_gcp_region") + t.Setenv("ALLOYDB_POSTGRES_CLUSTER", "your_alloydb_cluster") + t.Setenv("ALLOYDB_POSTGRES_INSTANCE", "your_alloydb_instance") + t.Setenv("ALLOYDB_POSTGRES_DATABASE", "your_alloydb_db") + t.Setenv("ALLOYDB_POSTGRES_USER", "your_alloydb_user") + t.Setenv("ALLOYDB_POSTGRES_PASSWORD", "your_alloydb_password") + + t.Setenv("ALLOYDB_OMNI_HOST", "localhost") + t.Setenv("ALLOYDB_OMNI_PORT", "5432") + t.Setenv("ALLOYDB_OMNI_DATABASE", "your_alloydb_db") + t.Setenv("ALLOYDB_OMNI_USER", "your_alloydb_user") + t.Setenv("ALLOYDB_OMNI_PASSWORD", "your_alloydb_password") + + t.Setenv("CLICKHOUSE_PROTOCOL", "your_clickhouse_protocol") + t.Setenv("CLICKHOUSE_DATABASE", "your_clickhouse_database") + t.Setenv("CLICKHOUSE_PASSWORD", "your_clickhouse_password") + t.Setenv("CLICKHOUSE_USER", "your_clickhouse_user") + t.Setenv("CLICKHOUSE_HOST", "your_clickhosue_host") + t.Setenv("CLICKHOUSE_PORT", "8123") + + t.Setenv("CLOUD_SQL_POSTGRES_PROJECT", "your_pg_project") + t.Setenv("CLOUD_SQL_POSTGRES_INSTANCE", "your_pg_instance") + t.Setenv("CLOUD_SQL_POSTGRES_DATABASE", "your_pg_db") + t.Setenv("CLOUD_SQL_POSTGRES_REGION", "your_pg_region") + t.Setenv("CLOUD_SQL_POSTGRES_USER", "your_pg_user") + t.Setenv("CLOUD_SQL_POSTGRES_PASS", "your_pg_pass") + + t.Setenv("CLOUD_SQL_MYSQL_PROJECT", "your_gcp_project_id") + t.Setenv("CLOUD_SQL_MYSQL_REGION", "your_gcp_region") + t.Setenv("CLOUD_SQL_MYSQL_INSTANCE", "your_instance") + t.Setenv("CLOUD_SQL_MYSQL_DATABASE", "your_cloudsql_mysql_db") + t.Setenv("CLOUD_SQL_MYSQL_USER", "your_cloudsql_mysql_user") + t.Setenv("CLOUD_SQL_MYSQL_PASSWORD", "your_cloudsql_mysql_password") + + t.Setenv("CLOUD_SQL_MSSQL_PROJECT", "your_gcp_project_id") + t.Setenv("CLOUD_SQL_MSSQL_REGION", "your_gcp_region") + t.Setenv("CLOUD_SQL_MSSQL_INSTANCE", "your_cloudsql_mssql_instance") + t.Setenv("CLOUD_SQL_MSSQL_DATABASE", "your_cloudsql_mssql_db") + t.Setenv("CLOUD_SQL_MSSQL_IP_ADDRESS", "127.0.0.1") + t.Setenv("CLOUD_SQL_MSSQL_USER", "your_cloudsql_mssql_user") + t.Setenv("CLOUD_SQL_MSSQL_PASSWORD", "your_cloudsql_mssql_password") + t.Setenv("CLOUD_SQL_POSTGRES_PASSWORD", "your_cloudsql_pg_password") + + t.Setenv("SERVERLESS_SPARK_PROJECT", "your_gcp_project_id") + t.Setenv("SERVERLESS_SPARK_LOCATION", "your_gcp_location") + + t.Setenv("POSTGRES_HOST", "localhost") + t.Setenv("POSTGRES_PORT", "5432") + t.Setenv("POSTGRES_DATABASE", "your_postgres_db") + t.Setenv("POSTGRES_USER", "your_postgres_user") + t.Setenv("POSTGRES_PASSWORD", "your_postgres_password") + + t.Setenv("MYSQL_HOST", "localhost") + t.Setenv("MYSQL_PORT", "3306") + t.Setenv("MYSQL_DATABASE", "your_mysql_db") + t.Setenv("MYSQL_USER", "your_mysql_user") + t.Setenv("MYSQL_PASSWORD", "your_mysql_password") + + t.Setenv("MSSQL_HOST", "localhost") + t.Setenv("MSSQL_PORT", "1433") + t.Setenv("MSSQL_DATABASE", "your_mssql_db") + t.Setenv("MSSQL_USER", "your_mssql_user") + t.Setenv("MSSQL_PASSWORD", "your_mssql_password") + + t.Setenv("MINDSDB_HOST", "localhost") + t.Setenv("MINDSDB_PORT", "47334") + t.Setenv("MINDSDB_DATABASE", "your_mindsdb_db") + t.Setenv("MINDSDB_USER", "your_mindsdb_user") + t.Setenv("MINDSDB_PASS", "your_mindsdb_password") + + t.Setenv("LOOKER_BASE_URL", "https://your_company.looker.com") + t.Setenv("LOOKER_CLIENT_ID", "your_looker_client_id") + t.Setenv("LOOKER_CLIENT_SECRET", "your_looker_client_secret") + t.Setenv("LOOKER_VERIFY_SSL", "true") + + t.Setenv("LOOKER_PROJECT", "your_project_id") + t.Setenv("LOOKER_LOCATION", "us") + + t.Setenv("SQLITE_DATABASE", "test.db") + + t.Setenv("NEO4J_URI", "bolt://localhost:7687") + t.Setenv("NEO4J_DATABASE", "neo4j") + t.Setenv("NEO4J_USERNAME", "your_neo4j_user") + t.Setenv("NEO4J_PASSWORD", "your_neo4j_password") + + t.Setenv("CLOUD_HEALTHCARE_PROJECT", "your_gcp_project_id") + t.Setenv("CLOUD_HEALTHCARE_REGION", "your_gcp_region") + t.Setenv("CLOUD_HEALTHCARE_DATASET", "your_healthcare_dataset") + + t.Setenv("SNOWFLAKE_ACCOUNT", "your_account") + t.Setenv("SNOWFLAKE_USER", "your_username") + t.Setenv("SNOWFLAKE_PASSWORD", "your_pass") + t.Setenv("SNOWFLAKE_DATABASE", "your_db") + t.Setenv("SNOWFLAKE_SCHEMA", "your_schema") + t.Setenv("SNOWFLAKE_WAREHOUSE", "your_wh") + t.Setenv("SNOWFLAKE_ROLE", "your_role") + + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + name string + in []byte + wantToolset server.ToolsetConfigs + }{ + { + name: "alloydb omni prebuilt tools", + in: alloydb_omni_config, + wantToolset: server.ToolsetConfigs{ + "alloydb_omni_database_tools": tools.ToolsetConfig{ + Name: "alloydb_omni_database_tools", + ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_columnar_configurations", "list_columnar_recommended_columns", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, + }, + }, + }, + { + name: "alloydb postgres admin prebuilt tools", + in: alloydb_admin_config, + wantToolset: server.ToolsetConfigs{ + "alloydb_postgres_admin_tools": tools.ToolsetConfig{ + Name: "alloydb_postgres_admin_tools", + ToolNames: []string{"create_cluster", "wait_for_operation", "create_instance", "list_clusters", "list_instances", "list_users", "create_user", "get_cluster", "get_instance", "get_user"}, + }, + }, + }, + { + name: "cloudsql pg admin prebuilt tools", + in: cloudsqlpg_admin_config, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_postgres_admin_tools": tools.ToolsetConfig{ + Name: "cloud_sql_postgres_admin_tools", + ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup", "restore_backup"}, + }, + }, + }, + { + name: "cloudsql mysql admin prebuilt tools", + in: cloudsqlmysql_admin_config, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_mysql_admin_tools": tools.ToolsetConfig{ + Name: "cloud_sql_mysql_admin_tools", + ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"}, + }, + }, + }, + { + name: "cloudsql mssql admin prebuilt tools", + in: cloudsqlmssql_admin_config, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_mssql_admin_tools": tools.ToolsetConfig{ + Name: "cloud_sql_mssql_admin_tools", + ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"}, + }, + }, + }, + { + name: "alloydb prebuilt tools", + in: alloydb_config, + wantToolset: server.ToolsetConfigs{ + "alloydb_postgres_database_tools": tools.ToolsetConfig{ + Name: "alloydb_postgres_database_tools", + ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, + }, + }, + }, + { + name: "bigquery prebuilt tools", + in: bigquery_config, + wantToolset: server.ToolsetConfigs{ + "bigquery_database_tools": tools.ToolsetConfig{ + Name: "bigquery_database_tools", + ToolNames: []string{"analyze_contribution", "ask_data_insights", "execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids", "search_catalog"}, + }, + }, + }, + { + name: "clickhouse prebuilt tools", + in: clickhouse_config, + wantToolset: server.ToolsetConfigs{ + "clickhouse_database_tools": tools.ToolsetConfig{ + Name: "clickhouse_database_tools", + ToolNames: []string{"execute_sql", "list_databases", "list_tables"}, + }, + }, + }, + { + name: "cloudsqlpg prebuilt tools", + in: cloudsqlpg_config, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_postgres_database_tools": tools.ToolsetConfig{ + Name: "cloud_sql_postgres_database_tools", + ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, + }, + }, + }, + { + name: "cloudsqlmysql prebuilt tools", + in: cloudsqlmysql_config, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_mysql_database_tools": tools.ToolsetConfig{ + Name: "cloud_sql_mysql_database_tools", + ToolNames: []string{"execute_sql", "list_tables", "get_query_plan", "list_active_queries", "list_tables_missing_unique_indexes", "list_table_fragmentation"}, + }, + }, + }, + { + name: "cloudsqlmssql prebuilt tools", + in: cloudsqlmssql_config, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_mssql_database_tools": tools.ToolsetConfig{ + Name: "cloud_sql_mssql_database_tools", + ToolNames: []string{"execute_sql", "list_tables"}, + }, + }, + }, + { + name: "dataplex prebuilt tools", + in: dataplex_config, + wantToolset: server.ToolsetConfigs{ + "dataplex_tools": tools.ToolsetConfig{ + Name: "dataplex_tools", + ToolNames: []string{"search_entries", "lookup_entry", "search_aspect_types"}, + }, + }, + }, + { + name: "serverless spark prebuilt tools", + in: serverless_spark_config, + wantToolset: server.ToolsetConfigs{ + "serverless_spark_tools": tools.ToolsetConfig{ + Name: "serverless_spark_tools", + ToolNames: []string{"list_batches", "get_batch", "cancel_batch", "create_pyspark_batch", "create_spark_batch"}, + }, + }, + }, + { + name: "firestore prebuilt tools", + in: firestoreconfig, + wantToolset: server.ToolsetConfigs{ + "firestore_database_tools": tools.ToolsetConfig{ + Name: "firestore_database_tools", + ToolNames: []string{"get_documents", "add_documents", "update_document", "list_collections", "delete_documents", "query_collection", "get_rules", "validate_rules"}, + }, + }, + }, + { + name: "mysql prebuilt tools", + in: mysql_config, + wantToolset: server.ToolsetConfigs{ + "mysql_database_tools": tools.ToolsetConfig{ + Name: "mysql_database_tools", + ToolNames: []string{"execute_sql", "list_tables", "get_query_plan", "list_active_queries", "list_tables_missing_unique_indexes", "list_table_fragmentation"}, + }, + }, + }, + { + name: "mssql prebuilt tools", + in: mssql_config, + wantToolset: server.ToolsetConfigs{ + "mssql_database_tools": tools.ToolsetConfig{ + Name: "mssql_database_tools", + ToolNames: []string{"execute_sql", "list_tables"}, + }, + }, + }, + { + name: "looker prebuilt tools", + in: looker_config, + wantToolset: server.ToolsetConfigs{ + "looker_tools": tools.ToolsetConfig{ + Name: "looker_tools", + ToolNames: []string{"get_models", "get_explores", "get_dimensions", "get_measures", "get_filters", "get_parameters", "query", "query_sql", "query_url", "get_looks", "run_look", "make_look", "get_dashboards", "run_dashboard", "make_dashboard", "add_dashboard_element", "add_dashboard_filter", "generate_embed_url", "health_pulse", "health_analyze", "health_vacuum", "dev_mode", "get_projects", "get_project_files", "get_project_file", "create_project_file", "update_project_file", "delete_project_file", "validate_project", "get_connections", "get_connection_schemas", "get_connection_databases", "get_connection_tables", "get_connection_table_columns"}, + }, + }, + }, + { + name: "looker-conversational-analytics prebuilt tools", + in: lookerca_config, + wantToolset: server.ToolsetConfigs{ + "looker_conversational_analytics_tools": tools.ToolsetConfig{ + Name: "looker_conversational_analytics_tools", + ToolNames: []string{"ask_data_insights", "get_models", "get_explores"}, + }, + }, + }, + { + name: "postgres prebuilt tools", + in: postgresconfig, + wantToolset: server.ToolsetConfigs{ + "postgres_database_tools": tools.ToolsetConfig{ + Name: "postgres_database_tools", + ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, + }, + }, + }, + { + name: "spanner prebuilt tools", + in: spanner_config, + wantToolset: server.ToolsetConfigs{ + "spanner-database-tools": tools.ToolsetConfig{ + Name: "spanner-database-tools", + ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables", "list_graphs"}, + }, + }, + }, + { + name: "spanner pg prebuilt tools", + in: spannerpg_config, + wantToolset: server.ToolsetConfigs{ + "spanner_postgres_database_tools": tools.ToolsetConfig{ + Name: "spanner_postgres_database_tools", + ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables"}, + }, + }, + }, + { + name: "mindsdb prebuilt tools", + in: mindsdb_config, + wantToolset: server.ToolsetConfigs{ + "mindsdb-tools": tools.ToolsetConfig{ + Name: "mindsdb-tools", + ToolNames: []string{"mindsdb-execute-sql", "mindsdb-sql"}, + }, + }, + }, + { + name: "sqlite prebuilt tools", + in: sqlite_config, + wantToolset: server.ToolsetConfigs{ + "sqlite_database_tools": tools.ToolsetConfig{ + Name: "sqlite_database_tools", + ToolNames: []string{"execute_sql", "list_tables"}, + }, + }, + }, + { + name: "neo4j prebuilt tools", + in: neo4jconfig, + wantToolset: server.ToolsetConfigs{ + "neo4j_database_tools": tools.ToolsetConfig{ + Name: "neo4j_database_tools", + ToolNames: []string{"execute_cypher", "get_schema"}, + }, + }, + }, + { + name: "alloydb postgres observability prebuilt tools", + in: alloydbobsvconfig, + wantToolset: server.ToolsetConfigs{ + "alloydb_postgres_cloud_monitoring_tools": tools.ToolsetConfig{ + Name: "alloydb_postgres_cloud_monitoring_tools", + ToolNames: []string{"get_system_metrics", "get_query_metrics"}, + }, + }, + }, + { + name: "cloudsql postgres observability prebuilt tools", + in: cloudsqlpgobsvconfig, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_postgres_cloud_monitoring_tools": tools.ToolsetConfig{ + Name: "cloud_sql_postgres_cloud_monitoring_tools", + ToolNames: []string{"get_system_metrics", "get_query_metrics"}, + }, + }, + }, + { + name: "cloudsql mysql observability prebuilt tools", + in: cloudsqlmysqlobsvconfig, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_mysql_cloud_monitoring_tools": tools.ToolsetConfig{ + Name: "cloud_sql_mysql_cloud_monitoring_tools", + ToolNames: []string{"get_system_metrics", "get_query_metrics"}, + }, + }, + }, + { + name: "cloudsql mssql observability prebuilt tools", + in: cloudsqlmssqlobsvconfig, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_mssql_cloud_monitoring_tools": tools.ToolsetConfig{ + Name: "cloud_sql_mssql_cloud_monitoring_tools", + ToolNames: []string{"get_system_metrics"}, + }, + }, + }, + { + name: "cloud healthcare prebuilt tools", + in: cloudhealthcare_config, + wantToolset: server.ToolsetConfigs{ + "cloud_healthcare_dataset_tools": tools.ToolsetConfig{ + Name: "cloud_healthcare_dataset_tools", + ToolNames: []string{"get_dataset", "list_dicom_stores", "list_fhir_stores"}, + }, + "cloud_healthcare_fhir_tools": tools.ToolsetConfig{ + Name: "cloud_healthcare_fhir_tools", + ToolNames: []string{"get_fhir_store", "get_fhir_store_metrics", "get_fhir_resource", "fhir_patient_search", "fhir_patient_everything", "fhir_fetch_page"}, + }, + "cloud_healthcare_dicom_tools": tools.ToolsetConfig{ + Name: "cloud_healthcare_dicom_tools", + ToolNames: []string{"get_dicom_store", "get_dicom_store_metrics", "search_dicom_studies", "search_dicom_series", "search_dicom_instances", "retrieve_rendered_dicom_instance"}, + }, + }, + }, + { + name: "Snowflake prebuilt tool", + in: snowflake_config, + wantToolset: server.ToolsetConfigs{ + "snowflake_tools": tools.ToolsetConfig{ + Name: "snowflake_tools", + ToolNames: []string{"execute_sql", "list_tables"}, + }, + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + toolsFile, err := parseToolsFile(ctx, tc.in) + if err != nil { + t.Fatalf("failed to parse input: %v", err) + } + if diff := cmp.Diff(tc.wantToolset, toolsFile.Toolsets); diff != "" { + t.Fatalf("incorrect tools parse: diff %v", diff) + } + // Prebuilt configs do not have prompts, so assert empty maps. + if len(toolsFile.Prompts) != 0 { + t.Fatalf("expected empty prompts map for prebuilt config, got: %v", toolsFile.Prompts) + } + }) + } +} + +func TestMergeToolsFiles(t *testing.T) { + file1 := ToolsFile{ + Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, + Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}}, + Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}}, + EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, + } + file2 := ToolsFile{ + AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}}, + Tools: server.ToolConfigs{"tool2": http.Config{Name: "tool2"}}, + Toolsets: server.ToolsetConfigs{"set2": tools.ToolsetConfig{Name: "set2"}}, + } + fileWithConflicts := ToolsFile{ + Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, + Tools: server.ToolConfigs{"tool2": http.Config{Name: "tool2"}}, + } + + testCases := []struct { + name string + files []ToolsFile + want ToolsFile + wantErr bool + }{ + { + name: "merge two distinct files", + files: []ToolsFile{file1, file2}, + want: ToolsFile{ + Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, + AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}}, + Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}}, + Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}}, + Prompts: server.PromptConfigs{}, + EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, + }, + wantErr: false, + }, + { + name: "merge with conflicts", + files: []ToolsFile{file1, file2, fileWithConflicts}, + wantErr: true, + }, + { + name: "merge single file", + files: []ToolsFile{file1}, + want: ToolsFile{ + Sources: file1.Sources, + AuthServices: make(server.AuthServiceConfigs), + EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, + Tools: file1.Tools, + Toolsets: file1.Toolsets, + Prompts: server.PromptConfigs{}, + }, + }, + { + name: "merge empty list", + files: []ToolsFile{}, + want: ToolsFile{ + Sources: make(server.SourceConfigs), + AuthServices: make(server.AuthServiceConfigs), + EmbeddingModels: make(server.EmbeddingModelConfigs), + Tools: make(server.ToolConfigs), + Toolsets: make(server.ToolsetConfigs), + Prompts: server.PromptConfigs{}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := mergeToolsFiles(tc.files...) + if (err != nil) != tc.wantErr { + t.Fatalf("mergeToolsFiles() error = %v, wantErr %v", err, tc.wantErr) + } + if !tc.wantErr { + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("mergeToolsFiles() mismatch (-want +got):\n%s", diff) + } + } else { + if err == nil { + t.Fatal("expected an error for conflicting files but got none") + } + if !strings.Contains(err.Error(), "resource conflicts detected") { + t.Errorf("expected conflict error, but got: %v", err) + } + } + }) + } +} + +func TestParameterReferenceValidation(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + // Base template + baseYaml := ` +sources: + dummy-source: + kind: http + baseUrl: http://example.com +tools: + test-tool: + kind: postgres-sql + source: dummy-source + description: test tool + statement: SELECT 1; + parameters: +%s` + + tcs := []struct { + desc string + params string + wantErr bool + errSubstr string + }{ + { + desc: "valid backward reference", + params: ` + - name: source_param + type: string + description: source + - name: copy_param + type: string + description: copy + valueFromParam: source_param`, + wantErr: false, + }, + { + desc: "valid forward reference (out of order)", + params: ` + - name: copy_param + type: string + description: copy + valueFromParam: source_param + - name: source_param + type: string + description: source`, + wantErr: false, + }, + { + desc: "invalid missing reference", + params: ` + - name: copy_param + type: string + description: copy + valueFromParam: non_existent_param`, + wantErr: true, + errSubstr: "references '\"non_existent_param\"' in the 'valueFromParam' field", + }, + { + desc: "invalid self reference", + params: ` + - name: myself + type: string + description: self + valueFromParam: myself`, + wantErr: true, + errSubstr: "parameter \"myself\" cannot copy value from itself", + }, + { + desc: "multiple valid references", + params: ` + - name: a + type: string + description: a + - name: b + type: string + description: b + valueFromParam: a + - name: c + type: string + description: c + valueFromParam: a`, + wantErr: false, + }, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Indent parameters to match YAML structure + yamlContent := fmt.Sprintf(baseYaml, tc.params) + + _, err := parseToolsFile(ctx, []byte(yamlContent)) + + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tc.errSubstr) { + t.Errorf("error %q does not contain expected substring %q", err.Error(), tc.errSubstr) + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } + }) + } +} diff --git a/cmd/options.go b/cmd/options.go deleted file mode 100644 index b87a7e6d55..0000000000 --- a/cmd/options.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "io" -) - -// Option is a function that configures a Command. -type Option func(*Command) - -// WithStreams overrides the default writer. -func WithStreams(out, err io.Writer) Option { - return func(c *Command) { - c.outStream = out - c.errStream = err - } -} diff --git a/cmd/root.go b/cmd/root.go index 3d62c11dc8..33383366ea 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -15,7 +15,6 @@ package cmd import ( - "bytes" "context" _ "embed" "fmt" @@ -24,7 +23,6 @@ import ( "os" "os/signal" "path/filepath" - "regexp" "runtime" "slices" "strings" @@ -32,261 +30,18 @@ import ( "time" "github.com/fsnotify/fsnotify" - yaml "github.com/goccy/go-yaml" + // Importing the cmd/internal package also import packages for side effect of registration + "github.com/googleapis/genai-toolbox/cmd/internal" + "github.com/googleapis/genai-toolbox/cmd/internal/invoke" + "github.com/googleapis/genai-toolbox/cmd/internal/skills" "github.com/googleapis/genai-toolbox/internal/auth" - "github.com/googleapis/genai-toolbox/internal/cli/invoke" - "github.com/googleapis/genai-toolbox/internal/cli/skills" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" - "github.com/googleapis/genai-toolbox/internal/log" - "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/telemetry" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" - - // Import prompt packages for side effect of registration - _ "github.com/googleapis/genai-toolbox/internal/prompts/custom" - - // Import tool packages for side effect of registration - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreatecluster" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreateinstance" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreateuser" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetcluster" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetinstance" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetuser" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistclusters" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistinstances" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistusers" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbwaitforoperation" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryanalyzecontribution" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryconversationalanalytics" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygetdatasetinfo" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygettableinfo" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylistdatasetids" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylisttableids" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysearchcatalog" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigtable" - _ "github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql" - _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases" - _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables" - _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdataset" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminlistlognames" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminlistresourcetypes" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminquerylogs" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistdatabases" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistinstances" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck" - _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblistschemas" - _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblisttables" - _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbsql" - _ "github.com/googleapis/genai-toolbox/internal/tools/couchbase" - _ "github.com/googleapis/genai-toolbox/internal/tools/dataform/dataformcompilelocal" - _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupentry" - _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchaspecttypes" - _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchentries" - _ "github.com/googleapis/genai-toolbox/internal/tools/dgraph" - _ "github.com/googleapis/genai-toolbox/internal/tools/elasticsearch/elasticsearchesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/firebird/firebirdexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/firebird/firebirdsql" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreadddocuments" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoredeletedocuments" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetdocuments" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetrules" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorelistcollections" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequery" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequerycollection" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreupdatedocument" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules" - _ "github.com/googleapis/genai-toolbox/internal/tools/http" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardfilter" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectfile" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdevmode" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergenerateembedurl" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiondatabases" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnections" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectionschemas" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontablecolumns" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdashboards" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdimensions" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetexplores" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetfilters" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetlooks" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmeasures" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmodels" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetparameters" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojectfile" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojectfiles" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojects" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthanalyze" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthpulse" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthvacuum" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookermakedashboard" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookermakelook" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquery" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquerysql" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerqueryurl" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrundashboard" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrunlook" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerupdateprojectfile" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookervalidateproject" - _ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbsql" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbaggregate" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeletemany" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeleteone" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfind" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfindone" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertmany" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertone" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdatemany" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdateone" - _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqllisttables" - _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllistactivequeries" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablefragmentation" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablesmissinguniqueindexes" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlsql" - _ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jcypher" - _ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher" - _ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema" - _ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbaseexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbasesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oracleexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresdatabaseoverview" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresgetcolumncardinality" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistactivequeries" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistavailableextensions" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistdatabasestats" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistindexes" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistinstalledextensions" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistlocks" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpgsettings" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpublicationtables" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistquerystats" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistroles" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistschemas" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistsequences" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresliststoredprocedure" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablespaces" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablestats" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttriggers" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistviews" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslongrunningtransactions" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresreplicationstats" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" - _ "github.com/googleapis/genai-toolbox/internal/tools/redis" - _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch" - _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch" - _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch" - _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch" - _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" - _ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoreexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoresql" - _ "github.com/googleapis/genai-toolbox/internal/tools/snowflake/snowflakeexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/snowflake/snowflakesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlistgraphs" - _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables" - _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannersql" - _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbsql" - _ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinoexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinosql" - _ "github.com/googleapis/genai-toolbox/internal/tools/utility/wait" - _ "github.com/googleapis/genai-toolbox/internal/tools/valkey" - _ "github.com/googleapis/genai-toolbox/internal/tools/yugabytedbsql" - "github.com/spf13/cobra" - - _ "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" - _ "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - _ "github.com/googleapis/genai-toolbox/internal/sources/bigquery" - _ "github.com/googleapis/genai-toolbox/internal/sources/bigtable" - _ "github.com/googleapis/genai-toolbox/internal/sources/cassandra" - _ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudloggingadmin" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - _ "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" - _ "github.com/googleapis/genai-toolbox/internal/sources/couchbase" - _ "github.com/googleapis/genai-toolbox/internal/sources/dataplex" - _ "github.com/googleapis/genai-toolbox/internal/sources/dgraph" - _ "github.com/googleapis/genai-toolbox/internal/sources/elasticsearch" - _ "github.com/googleapis/genai-toolbox/internal/sources/firebird" - _ "github.com/googleapis/genai-toolbox/internal/sources/firestore" - _ "github.com/googleapis/genai-toolbox/internal/sources/http" - _ "github.com/googleapis/genai-toolbox/internal/sources/looker" - _ "github.com/googleapis/genai-toolbox/internal/sources/mindsdb" - _ "github.com/googleapis/genai-toolbox/internal/sources/mongodb" - _ "github.com/googleapis/genai-toolbox/internal/sources/mssql" - _ "github.com/googleapis/genai-toolbox/internal/sources/mysql" - _ "github.com/googleapis/genai-toolbox/internal/sources/neo4j" - _ "github.com/googleapis/genai-toolbox/internal/sources/oceanbase" - _ "github.com/googleapis/genai-toolbox/internal/sources/oracle" - _ "github.com/googleapis/genai-toolbox/internal/sources/postgres" - _ "github.com/googleapis/genai-toolbox/internal/sources/redis" - _ "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" - _ "github.com/googleapis/genai-toolbox/internal/sources/singlestore" - _ "github.com/googleapis/genai-toolbox/internal/sources/snowflake" - _ "github.com/googleapis/genai-toolbox/internal/sources/spanner" - _ "github.com/googleapis/genai-toolbox/internal/sources/sqlite" - _ "github.com/googleapis/genai-toolbox/internal/sources/tidb" - _ "github.com/googleapis/genai-toolbox/internal/sources/trino" - _ "github.com/googleapis/genai-toolbox/internal/sources/valkey" - _ "github.com/googleapis/genai-toolbox/internal/sources/yugabytedb" ) var ( @@ -315,425 +70,74 @@ func semanticVersion() string { return v } +// GenerateCommand returns a new Command object with the specified IO streams +// This is used for integration test package +func GenerateCommand(out, err io.Writer) *cobra.Command { + opts := internal.NewToolboxOptions(internal.WithIOStreams(out, err)) + return NewCommand(opts) +} + // Execute adds all child commands to the root command and sets flags appropriately. // This is called by main.main(). It only needs to happen once to the rootCmd. func Execute() { - if err := NewCommand().Execute(); err != nil { + // Initialize options + opts := internal.NewToolboxOptions() + + if err := NewCommand(opts).Execute(); err != nil { exit := 1 os.Exit(exit) } } -// Command represents an invocation of the CLI. -type Command struct { - *cobra.Command - - cfg server.ServerConfig - logger log.Logger - tools_file string - tools_files []string - tools_folder string - prebuiltConfigs []string - inStream io.Reader - outStream io.Writer - errStream io.Writer -} - // NewCommand returns a Command object representing an invocation of the CLI. -func NewCommand(opts ...Option) *Command { - in := os.Stdin - out := os.Stdout - err := os.Stderr - - baseCmd := &cobra.Command{ +func NewCommand(opts *internal.ToolboxOptions) *cobra.Command { + cmd := &cobra.Command{ Use: "toolbox", Version: versionString, SilenceErrors: true, } - cmd := &Command{ - Command: baseCmd, - inStream: in, - outStream: out, - errStream: err, - } - - for _, o := range opts { - o(cmd) - } // Do not print Usage on runtime error cmd.SilenceUsage = true // Set server version - cmd.cfg.Version = versionString + opts.Cfg.Version = versionString // set baseCmd in, out and err the same as cmd. - baseCmd.SetIn(cmd.inStream) - baseCmd.SetOut(cmd.outStream) - baseCmd.SetErr(cmd.errStream) + cmd.SetIn(opts.IOStreams.In) + cmd.SetOut(opts.IOStreams.Out) + cmd.SetErr(opts.IOStreams.ErrOut) + + // setup flags that are common across all commands + internal.PersistentFlags(cmd, opts) flags := cmd.Flags() - persistentFlags := cmd.PersistentFlags() - flags.StringVarP(&cmd.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.") - flags.IntVarP(&cmd.cfg.Port, "port", "p", 5000, "Port the server will listen on.") + flags.StringVarP(&opts.Cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.") + flags.IntVarP(&opts.Cfg.Port, "port", "p", 5000, "Port the server will listen on.") - flags.StringVar(&cmd.tools_file, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.") + flags.StringVar(&opts.ToolsFile, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.") // deprecate tools_file _ = flags.MarkDeprecated("tools_file", "please use --tools-file instead") - persistentFlags.StringVar(&cmd.tools_file, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.") - persistentFlags.StringSliceVar(&cmd.tools_files, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file, or --tools-folder.") - persistentFlags.StringVar(&cmd.tools_folder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file, or --tools-files.") - persistentFlags.Var(&cmd.cfg.LogLevel, "log-level", "Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'.") - persistentFlags.Var(&cmd.cfg.LoggingFormat, "logging-format", "Specify logging format to use. Allowed: 'standard' or 'JSON'.") - persistentFlags.BoolVar(&cmd.cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.") - persistentFlags.StringVar(&cmd.cfg.TelemetryOTLP, "telemetry-otlp", "", "Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318')") - persistentFlags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.") - // Fetch prebuilt tools sources to customize the help description - prebuiltHelp := fmt.Sprintf( - "Use a prebuilt tool configuration by source type. Allowed: '%s'. Can be specified multiple times.", - strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"), - ) - persistentFlags.StringSliceVar(&cmd.prebuiltConfigs, "prebuilt", []string{}, prebuiltHelp) - flags.BoolVar(&cmd.cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.") - flags.BoolVar(&cmd.cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.") - flags.BoolVar(&cmd.cfg.UI, "ui", false, "Launches the Toolbox UI web server.") + flags.BoolVar(&opts.Cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.") + flags.BoolVar(&opts.Cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.") + flags.BoolVar(&opts.Cfg.UI, "ui", false, "Launches the Toolbox UI web server.") // TODO: Insecure by default. Might consider updating this for v1.0.0 - flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.") - flags.StringSliceVar(&cmd.cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.") - persistentFlags.StringSliceVar(&cmd.cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.") + flags.StringSliceVar(&opts.Cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.") + flags.StringSliceVar(&opts.Cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.") // wrap RunE command so that we have access to original Command object - cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) } + cmd.RunE = func(*cobra.Command, []string) error { return run(cmd, opts) } // Register subcommands for tool invocation - baseCmd.AddCommand(invoke.NewCommand(cmd)) + cmd.AddCommand(invoke.NewCommand(opts)) // Register subcommands for skill generation - baseCmd.AddCommand(skills.NewCommand(cmd)) + cmd.AddCommand(skills.NewCommand(opts)) return cmd } -type ToolsFile struct { - Sources server.SourceConfigs `yaml:"sources"` - AuthServices server.AuthServiceConfigs `yaml:"authServices"` - EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"` - Tools server.ToolConfigs `yaml:"tools"` - Toolsets server.ToolsetConfigs `yaml:"toolsets"` - Prompts server.PromptConfigs `yaml:"prompts"` -} - -// parseEnv replaces environment variables ${ENV_NAME} with their values. -// also support ${ENV_NAME:default_value}. -func parseEnv(input string) (string, error) { - re := regexp.MustCompile(`\$\{(\w+)(:([^}]*))?\}`) - - var err error - output := re.ReplaceAllStringFunc(input, func(match string) string { - parts := re.FindStringSubmatch(match) - - // extract the variable name - variableName := parts[1] - if value, found := os.LookupEnv(variableName); found { - return value - } - if len(parts) >= 4 && parts[2] != "" { - return parts[3] - } - err = fmt.Errorf("environment variable not found: %q", variableName) - return "" - }) - return output, err -} - -func convertToolsFile(raw []byte) ([]byte, error) { - var input yaml.MapSlice - decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap()) - - // convert to tools file v2 - var buf bytes.Buffer - encoder := yaml.NewEncoder(&buf) - - v1keys := []string{"sources", "authSources", "authServices", "embeddingModels", "tools", "toolsets", "prompts"} - for { - if err := decoder.Decode(&input); err != nil { - if err == io.EOF { - break - } - return nil, err - } - for _, item := range input { - key, ok := item.Key.(string) - if !ok { - return nil, fmt.Errorf("unexpected non-string key in input: %v", item.Key) - } - // check if the key is config file v1's key - if slices.Contains(v1keys, key) { - // check if value conversion to yaml.MapSlice successfully - // fields such as "tools" in toolsets might pass the first check but - // fail to convert to MapSlice - if slice, ok := item.Value.(yaml.MapSlice); ok { - // Deprecated: convert authSources to authServices - if key == "authSources" { - key = "authServices" - } - transformed, err := transformDocs(key, slice) - if err != nil { - return nil, err - } - // encode per-doc - for _, doc := range transformed { - if err := encoder.Encode(doc); err != nil { - return nil, err - } - } - } else { - // invalid input will be ignored - // we don't want to throw error here since the config could - // be valid but with a different order such as: - // --- - // tools: - // - tool_a - // kind: toolsets - // --- - continue - } - } else { - // this doc is already v2, encode to buf - if err := encoder.Encode(input); err != nil { - return nil, err - } - break - } - } - } - return buf.Bytes(), nil -} - -// transformDocs transforms the configuration file from v1 format to v2 -// yaml.MapSlice will preserve the order in a map -func transformDocs(kind string, input yaml.MapSlice) ([]yaml.MapSlice, error) { - var transformed []yaml.MapSlice - for _, entry := range input { - entryName, ok := entry.Key.(string) - if !ok { - return nil, fmt.Errorf("unexpected non-string key for entry in '%s': %v", kind, entry.Key) - } - entryBody := ProcessValue(entry.Value, kind == "toolsets") - - currentTransformed := yaml.MapSlice{ - {Key: "kind", Value: kind}, - {Key: "name", Value: entryName}, - } - - // Merge the transformed body into our result - if bodySlice, ok := entryBody.(yaml.MapSlice); ok { - currentTransformed = append(currentTransformed, bodySlice...) - } else { - return nil, fmt.Errorf("unable to convert entryBody to MapSlice") - } - transformed = append(transformed, currentTransformed) - } - return transformed, nil -} - -// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type' -func ProcessValue(v any, isToolset bool) any { - switch val := v.(type) { - case yaml.MapSlice: - // creating a new MapSlice is safer for recursive transformation - newVal := make(yaml.MapSlice, len(val)) - for i, item := range val { - // Perform renaming - if item.Key == "kind" { - item.Key = "type" - } - // Recursive call for nested values (e.g., nested objects or lists) - item.Value = ProcessValue(item.Value, false) - newVal[i] = item - } - return newVal - case []any: - // Process lists: If it's a toolset top-level list, wrap it. - if isToolset { - return yaml.MapSlice{{Key: "tools", Value: val}} - } - // Otherwise, recurse into list items (to catch nested objects) - newVal := make([]any, len(val)) - for i := range val { - newVal[i] = ProcessValue(val[i], false) - } - return newVal - default: - return val - } -} - -// parseToolsFile parses the provided yaml into appropriate configs. -func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) { - var toolsFile ToolsFile - // Replace environment variables if found - output, err := parseEnv(string(raw)) - if err != nil { - return toolsFile, fmt.Errorf("error parsing environment variables: %s", err) - } - raw = []byte(output) - - raw, err = convertToolsFile(raw) - if err != nil { - return toolsFile, fmt.Errorf("error converting tools file: %s", err) - } - - // Parse contents - toolsFile.Sources, toolsFile.AuthServices, toolsFile.EmbeddingModels, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts, err = server.UnmarshalResourceConfig(ctx, raw) - if err != nil { - return toolsFile, err - } - return toolsFile, nil -} - -// mergeToolsFiles merges multiple ToolsFile structs into one. -// Detects and raises errors for resource conflicts in sources, authServices, tools, and toolsets. -// All resource names (sources, authServices, tools, toolsets) must be unique across all files. -func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) { - merged := ToolsFile{ - Sources: make(server.SourceConfigs), - AuthServices: make(server.AuthServiceConfigs), - EmbeddingModels: make(server.EmbeddingModelConfigs), - Tools: make(server.ToolConfigs), - Toolsets: make(server.ToolsetConfigs), - Prompts: make(server.PromptConfigs), - } - - var conflicts []string - - for fileIndex, file := range files { - // Check for conflicts and merge sources - for name, source := range file.Sources { - if _, exists := merged.Sources[name]; exists { - conflicts = append(conflicts, fmt.Sprintf("source '%s' (file #%d)", name, fileIndex+1)) - } else { - merged.Sources[name] = source - } - } - - // Check for conflicts and merge authServices - for name, authService := range file.AuthServices { - if _, exists := merged.AuthServices[name]; exists { - conflicts = append(conflicts, fmt.Sprintf("authService '%s' (file #%d)", name, fileIndex+1)) - } else { - merged.AuthServices[name] = authService - } - } - - // Check for conflicts and merge embeddingModels - for name, em := range file.EmbeddingModels { - if _, exists := merged.EmbeddingModels[name]; exists { - conflicts = append(conflicts, fmt.Sprintf("embedding model '%s' (file #%d)", name, fileIndex+1)) - } else { - merged.EmbeddingModels[name] = em - } - } - - // Check for conflicts and merge tools - for name, tool := range file.Tools { - if _, exists := merged.Tools[name]; exists { - conflicts = append(conflicts, fmt.Sprintf("tool '%s' (file #%d)", name, fileIndex+1)) - } else { - merged.Tools[name] = tool - } - } - - // Check for conflicts and merge toolsets - for name, toolset := range file.Toolsets { - if _, exists := merged.Toolsets[name]; exists { - conflicts = append(conflicts, fmt.Sprintf("toolset '%s' (file #%d)", name, fileIndex+1)) - } else { - merged.Toolsets[name] = toolset - } - } - - // Check for conflicts and merge prompts - for name, prompt := range file.Prompts { - if _, exists := merged.Prompts[name]; exists { - conflicts = append(conflicts, fmt.Sprintf("prompt '%s' (file #%d)", name, fileIndex+1)) - } else { - merged.Prompts[name] = prompt - } - } - } - - // If conflicts were detected, return an error - if len(conflicts) > 0 { - return ToolsFile{}, fmt.Errorf("resource conflicts detected:\n - %s\n\nPlease ensure each source, authService, tool, toolset and prompt has a unique name across all files", strings.Join(conflicts, "\n - ")) - } - - return merged, nil -} - -// loadAndMergeToolsFiles loads multiple YAML files and merges them -func loadAndMergeToolsFiles(ctx context.Context, filePaths []string) (ToolsFile, error) { - var toolsFiles []ToolsFile - - for _, filePath := range filePaths { - buf, err := os.ReadFile(filePath) - if err != nil { - return ToolsFile{}, fmt.Errorf("unable to read tool file at %q: %w", filePath, err) - } - - toolsFile, err := parseToolsFile(ctx, buf) - if err != nil { - return ToolsFile{}, fmt.Errorf("unable to parse tool file at %q: %w", filePath, err) - } - - toolsFiles = append(toolsFiles, toolsFile) - } - - mergedFile, err := mergeToolsFiles(toolsFiles...) - if err != nil { - return ToolsFile{}, fmt.Errorf("unable to merge tools files: %w", err) - } - - return mergedFile, nil -} - -// loadAndMergeToolsFolder loads all YAML files from a directory and merges them -func loadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile, error) { - // Check if directory exists - info, err := os.Stat(folderPath) - if err != nil { - return ToolsFile{}, fmt.Errorf("unable to access tools folder at %q: %w", folderPath, err) - } - if !info.IsDir() { - return ToolsFile{}, fmt.Errorf("path %q is not a directory", folderPath) - } - - // Find all YAML files in the directory - pattern := filepath.Join(folderPath, "*.yaml") - yamlFiles, err := filepath.Glob(pattern) - if err != nil { - return ToolsFile{}, fmt.Errorf("error finding YAML files in %q: %w", folderPath, err) - } - - // Also find .yml files - ymlPattern := filepath.Join(folderPath, "*.yml") - ymlFiles, err := filepath.Glob(ymlPattern) - if err != nil { - return ToolsFile{}, fmt.Errorf("error finding YML files in %q: %w", folderPath, err) - } - - // Combine both file lists - allFiles := append(yamlFiles, ymlFiles...) - - if len(allFiles) == 0 { - return ToolsFile{}, fmt.Errorf("no YAML files found in directory %q", folderPath) - } - - // Use existing loadAndMergeToolsFiles function - return loadAndMergeToolsFiles(ctx, allFiles) -} - -func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Server) error { +func handleDynamicReload(ctx context.Context, toolsFile internal.ToolsFile, s *server.Server) error { logger, err := util.LoggerFromContext(ctx) if err != nil { panic(err) @@ -753,7 +157,7 @@ func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Ser // validateReloadEdits checks that the reloaded tools file configs can initialized without failing func validateReloadEdits( - ctx context.Context, toolsFile ToolsFile, + ctx context.Context, toolsFile internal.ToolsFile, ) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error, ) { logger, err := util.LoggerFromContext(ctx) @@ -879,18 +283,18 @@ func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles m case <-debounce.C: debounce.Stop() - var reloadedToolsFile ToolsFile + var reloadedToolsFile internal.ToolsFile if watchingFolder { logger.DebugContext(ctx, "Reloading tools folder.") - reloadedToolsFile, err = loadAndMergeToolsFolder(ctx, folderToWatch) + reloadedToolsFile, err = internal.LoadAndMergeToolsFolder(ctx, folderToWatch) if err != nil { logger.WarnContext(ctx, "error loading tools folder %s", err) continue } } else { logger.DebugContext(ctx, "Reloading tools file(s).") - reloadedToolsFile, err = loadAndMergeToolsFiles(ctx, slices.Collect(maps.Keys(watchedFiles))) + reloadedToolsFile, err = internal.LoadAndMergeToolsFiles(ctx, slices.Collect(maps.Keys(watchedFiles))) if err != nil { logger.WarnContext(ctx, "error loading tools files %s", err) continue @@ -934,184 +338,7 @@ func resolveWatcherInputs(toolsFile string, toolsFiles []string, toolsFolder str return watchDirs, watchedFiles } -func (cmd *Command) Config() server.ServerConfig { - return cmd.cfg -} - -func (cmd *Command) Out() io.Writer { - return cmd.outStream -} - -func (cmd *Command) Logger() log.Logger { - return cmd.logger -} - -func (cmd *Command) LoadConfig(ctx context.Context) error { - logger, err := util.LoggerFromContext(ctx) - if err != nil { - return err - } - - var allToolsFiles []ToolsFile - - // Load Prebuilt Configuration - - if len(cmd.prebuiltConfigs) > 0 { - slices.Sort(cmd.prebuiltConfigs) - sourcesList := strings.Join(cmd.prebuiltConfigs, ", ") - logMsg := fmt.Sprintf("Using prebuilt tool configurations for: %s", sourcesList) - logger.InfoContext(ctx, logMsg) - - for _, configName := range cmd.prebuiltConfigs { - buf, err := prebuiltconfigs.Get(configName) - if err != nil { - logger.ErrorContext(ctx, err.Error()) - return err - } - - // Parse into ToolsFile struct - parsed, err := parseToolsFile(ctx, buf) - if err != nil { - errMsg := fmt.Errorf("unable to parse prebuilt tool configuration for '%s': %w", configName, err) - logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - allToolsFiles = append(allToolsFiles, parsed) - } - } - - // Determine if Custom Files should be loaded - // Check for explicit custom flags - isCustomConfigured := cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != "" - - // Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags) - useDefaultToolsFile := len(cmd.prebuiltConfigs) == 0 && !isCustomConfigured - - if useDefaultToolsFile { - cmd.tools_file = "tools.yaml" - isCustomConfigured = true - } - - // Load Custom Configurations - if isCustomConfigured { - // Enforce exclusivity among custom flags (tools-file vs tools-files vs tools-folder) - if (cmd.tools_file != "" && len(cmd.tools_files) > 0) || - (cmd.tools_file != "" && cmd.tools_folder != "") || - (len(cmd.tools_files) > 0 && cmd.tools_folder != "") { - errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously") - logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - - var customTools ToolsFile - var err error - - if len(cmd.tools_files) > 0 { - // Use tools-files - logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files))) - customTools, err = loadAndMergeToolsFiles(ctx, cmd.tools_files) - } else if cmd.tools_folder != "" { - // Use tools-folder - logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder)) - customTools, err = loadAndMergeToolsFolder(ctx, cmd.tools_folder) - } else { - // Use single file (tools-file or default `tools.yaml`) - buf, readFileErr := os.ReadFile(cmd.tools_file) - if readFileErr != nil { - errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, readFileErr) - logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - customTools, err = parseToolsFile(ctx, buf) - if err != nil { - err = fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err) - } - } - - if err != nil { - logger.ErrorContext(ctx, err.Error()) - return err - } - allToolsFiles = append(allToolsFiles, customTools) - } - - // Modify version string based on loaded configurations - if len(cmd.prebuiltConfigs) > 0 { - tag := "prebuilt" - if isCustomConfigured { - tag = "custom" - } - // cmd.prebuiltConfigs is already sorted above - for _, configName := range cmd.prebuiltConfigs { - cmd.cfg.Version += fmt.Sprintf("+%s.%s", tag, configName) - } - } - - // Merge Everything - // This will error if custom tools collide with prebuilt tools - finalToolsFile, err := mergeToolsFiles(allToolsFiles...) - if err != nil { - logger.ErrorContext(ctx, err.Error()) - return err - } - - cmd.cfg.SourceConfigs = finalToolsFile.Sources - cmd.cfg.AuthServiceConfigs = finalToolsFile.AuthServices - cmd.cfg.EmbeddingModelConfigs = finalToolsFile.EmbeddingModels - cmd.cfg.ToolConfigs = finalToolsFile.Tools - cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets - cmd.cfg.PromptConfigs = finalToolsFile.Prompts - - return nil -} - -func (cmd *Command) Setup(ctx context.Context) (context.Context, func(context.Context) error, error) { - // If stdio, set logger's out stream (usually DEBUG and INFO logs) to errStream - loggerOut := cmd.outStream - if cmd.cfg.Stdio { - loggerOut = cmd.errStream - } - - // Handle logger separately from config - logger, err := log.NewLogger(cmd.cfg.LoggingFormat.String(), cmd.cfg.LogLevel.String(), loggerOut, cmd.errStream) - if err != nil { - return ctx, nil, fmt.Errorf("unable to initialize logger: %w", err) - } - cmd.logger = logger - - ctx = util.WithLogger(ctx, cmd.logger) - - // Set up OpenTelemetry - otelShutdown, err := telemetry.SetupOTel(ctx, cmd.cfg.Version, cmd.cfg.TelemetryOTLP, cmd.cfg.TelemetryGCP, cmd.cfg.TelemetryServiceName) - if err != nil { - errMsg := fmt.Errorf("error setting up OpenTelemetry: %w", err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return ctx, nil, errMsg - } - - shutdownFunc := func(ctx context.Context) error { - err := otelShutdown(ctx) - if err != nil { - errMsg := fmt.Errorf("error shutting down OpenTelemetry: %w", err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return err - } - return nil - } - - instrumentation, err := telemetry.CreateTelemetryInstrumentation(cmd.cfg.Version) - if err != nil { - errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return ctx, shutdownFunc, errMsg - } - - ctx = util.WithInstrumentation(ctx, instrumentation) - - return ctx, shutdownFunc, nil -} - -func run(cmd *Command) error { +func run(cmd *cobra.Command, opts *internal.ToolboxOptions) error { ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() @@ -1128,14 +355,14 @@ func run(cmd *Command) error { } switch s { case syscall.SIGINT: - cmd.logger.DebugContext(sCtx, "Received SIGINT signal to shutdown.") + opts.Logger.DebugContext(sCtx, "Received SIGINT signal to shutdown.") case syscall.SIGTERM: - cmd.logger.DebugContext(sCtx, "Sending SIGTERM signal to shutdown.") + opts.Logger.DebugContext(sCtx, "Sending SIGTERM signal to shutdown.") } cancel() }(ctx) - ctx, shutdown, err := cmd.Setup(ctx) + ctx, shutdown, err := opts.Setup(ctx) if err != nil { return err } @@ -1143,24 +370,25 @@ func run(cmd *Command) error { _ = shutdown(ctx) }() - if err := cmd.LoadConfig(ctx); err != nil { + isCustomConfigured, err := opts.LoadConfig(ctx) + if err != nil { return err } // start server - s, err := server.NewServer(ctx, cmd.cfg) + s, err := server.NewServer(ctx, opts.Cfg) if err != nil { errMsg := fmt.Errorf("toolbox failed to initialize: %w", err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } // run server in background srvErr := make(chan error) - if cmd.cfg.Stdio { + if opts.Cfg.Stdio { go func() { defer close(srvErr) - err = s.ServeStdio(ctx, cmd.inStream, cmd.outStream) + err = s.ServeStdio(ctx, opts.IOStreams.In, opts.IOStreams.Out) if err != nil { srvErr <- err } @@ -1169,12 +397,12 @@ func run(cmd *Command) error { err = s.Listen(ctx) if err != nil { errMsg := fmt.Errorf("toolbox failed to start listener: %w", err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } - cmd.logger.InfoContext(ctx, "Server ready to serve!") - if cmd.cfg.UI { - cmd.logger.InfoContext(ctx, fmt.Sprintf("Toolbox UI is up and running at: http://%s:%d/ui", cmd.cfg.Address, cmd.cfg.Port)) + opts.Logger.InfoContext(ctx, "Server ready to serve!") + if opts.Cfg.UI { + opts.Logger.InfoContext(ctx, fmt.Sprintf("Toolbox UI is up and running at: http://%s:%d/ui", opts.Cfg.Address, opts.Cfg.Port)) } go func() { @@ -1186,11 +414,8 @@ func run(cmd *Command) error { }() } - // Determine if Custom Files are configured (re-check as loadAndMergeConfig might have updated defaults) - isCustomConfigured := cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != "" - - if isCustomConfigured && !cmd.cfg.DisableReload { - watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder) + if isCustomConfigured && !opts.Cfg.DisableReload { + watchDirs, watchedFiles := resolveWatcherInputs(opts.ToolsFile, opts.ToolsFiles, opts.ToolsFolder) // start watching the file(s) or folder for changes to trigger dynamic reloading go watchChanges(ctx, watchDirs, watchedFiles, s) } @@ -1200,13 +425,13 @@ func run(cmd *Command) error { case err := <-srvErr: if err != nil { errMsg := fmt.Errorf("toolbox crashed with the following error: %w", err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } case <-ctx.Done(): shutdownContext, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cmd.logger.WarnContext(shutdownContext, "Shutting down gracefully...") + opts.Logger.WarnContext(shutdownContext, "Shutting down gracefully...") err := s.Shutdown(shutdownContext) if err == context.DeadlineExceeded { return fmt.Errorf("graceful shutdown timed out... forcing exit") diff --git a/cmd/root_test.go b/cmd/root_test.go index f26bd1706a..e85aaa3d26 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -31,22 +31,12 @@ import ( "github.com/google/go-cmp/cmp" - "github.com/googleapis/genai-toolbox/internal/auth/google" - "github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini" + "github.com/googleapis/genai-toolbox/cmd/internal" "github.com/googleapis/genai-toolbox/internal/log" - "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" - "github.com/googleapis/genai-toolbox/internal/prompts" - "github.com/googleapis/genai-toolbox/internal/prompts/custom" "github.com/googleapis/genai-toolbox/internal/server" - cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http" "github.com/googleapis/genai-toolbox/internal/telemetry" "github.com/googleapis/genai-toolbox/internal/testutils" - "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/http" - "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" "github.com/googleapis/genai-toolbox/internal/util" - "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/spf13/cobra" ) @@ -76,15 +66,16 @@ func withDefaults(c server.ServerConfig) server.ServerConfig { return c } -func invokeCommand(args []string) (*Command, string, error) { - c := NewCommand() +func invokeCommand(args []string) (*cobra.Command, *internal.ToolboxOptions, string, error) { + buf := new(bytes.Buffer) + opts := internal.NewToolboxOptions(internal.WithIOStreams(buf, buf)) + c := NewCommand(opts) // Keep the test output quiet c.SilenceUsage = true c.SilenceErrors = true // Capture output - buf := new(bytes.Buffer) c.SetOut(buf) c.SetErr(buf) c.SetArgs(args) @@ -96,22 +87,23 @@ func invokeCommand(args []string) (*Command, string, error) { err := c.Execute() - return c, buf.String(), err + return c, opts, buf.String(), err } // invokeCommandWithContext executes the command with a context and returns the captured output. -func invokeCommandWithContext(ctx context.Context, args []string) (*Command, string, error) { - // Capture output using a buffer +func invokeCommandWithContext(ctx context.Context, args []string) (*cobra.Command, *internal.ToolboxOptions, string, error) { buf := new(bytes.Buffer) - c := NewCommand(WithStreams(buf, buf)) + opts := internal.NewToolboxOptions(internal.WithIOStreams(buf, buf)) + c := NewCommand(opts) + // Capture output using a buffer c.SetArgs(args) c.SilenceUsage = true c.SilenceErrors = true c.SetContext(ctx) err := c.Execute() - return c, buf.String(), err + return c, opts, buf.String(), err } func TestVersion(t *testing.T) { @@ -121,7 +113,7 @@ func TestVersion(t *testing.T) { } want := strings.TrimSpace(string(data)) - _, got, err := invokeCommand([]string{"--version"}) + _, _, got, err := invokeCommand([]string{"--version"}) if err != nil { t.Fatalf("error invoking command: %s", err) } @@ -243,79 +235,13 @@ func TestServerConfigFlags(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - c, _, err := invokeCommand(tc.args) + _, opts, _, err := invokeCommand(tc.args) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - if !cmp.Equal(c.cfg, tc.want) { - t.Fatalf("got %v, want %v", c.cfg, tc.want) - } - }) - } -} - -func TestParseEnv(t *testing.T) { - tcs := []struct { - desc string - env map[string]string - in string - want string - err bool - errString string - }{ - { - desc: "without default without env", - in: "${FOO}", - want: "", - err: true, - errString: `environment variable not found: "FOO"`, - }, - { - desc: "without default with env", - env: map[string]string{ - "FOO": "bar", - }, - in: "${FOO}", - want: "bar", - }, - { - desc: "with empty default", - in: "${FOO:}", - want: "", - }, - { - desc: "with default", - in: "${FOO:bar}", - want: "bar", - }, - { - desc: "with default with env", - env: map[string]string{ - "FOO": "hello", - }, - in: "${FOO:bar}", - want: "hello", - }, - } - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - if tc.env != nil { - for k, v := range tc.env { - t.Setenv(k, v) - } - } - got, err := parseEnv(tc.in) - if tc.err { - if err == nil { - t.Fatalf("expected error not found") - } - if tc.errString != err.Error() { - t.Fatalf("incorrect error string: got %s, want %s", err, tc.errString) - } - } - if tc.want != got { - t.Fatalf("unexpected want: got %s, want %s", got, tc.want) + if !cmp.Equal(opts.Cfg, tc.want) { + t.Fatalf("got %v, want %v", opts.Cfg, tc.want) } }) } @@ -350,12 +276,12 @@ func TestToolFileFlag(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - c, _, err := invokeCommand(tc.args) + _, opts, _, err := invokeCommand(tc.args) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - if c.tools_file != tc.want { - t.Fatalf("got %v, want %v", c.cfg, tc.want) + if opts.ToolsFile != tc.want { + t.Fatalf("got %v, want %v", opts.Cfg, tc.want) } }) } @@ -385,12 +311,12 @@ func TestToolsFilesFlag(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - c, _, err := invokeCommand(tc.args) + _, opts, _, err := invokeCommand(tc.args) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - if diff := cmp.Diff(c.tools_files, tc.want); diff != "" { - t.Fatalf("got %v, want %v", c.tools_files, tc.want) + if diff := cmp.Diff(opts.ToolsFiles, tc.want); diff != "" { + t.Fatalf("got %v, want %v", opts.ToolsFiles, tc.want) } }) } @@ -415,12 +341,12 @@ func TestToolsFolderFlag(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - c, _, err := invokeCommand(tc.args) + _, opts, _, err := invokeCommand(tc.args) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - if c.tools_folder != tc.want { - t.Fatalf("got %v, want %v", c.tools_folder, tc.want) + if opts.ToolsFolder != tc.want { + t.Fatalf("got %v, want %v", opts.ToolsFolder, tc.want) } }) } @@ -455,12 +381,12 @@ func TestPrebuiltFlag(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - c, _, err := invokeCommand(tc.args) + _, opts, _, err := invokeCommand(tc.args) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - if diff := cmp.Diff(c.prebuiltConfigs, tc.want); diff != "" { - t.Fatalf("got %v, want %v, diff %s", c.prebuiltConfigs, tc.want, diff) + if diff := cmp.Diff(opts.PrebuiltConfigs, tc.want); diff != "" { + t.Fatalf("got %v, want %v, diff %s", opts.PrebuiltConfigs, tc.want, diff) } }) } @@ -482,7 +408,7 @@ func TestFailServerConfigFlags(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - _, _, err := invokeCommand(tc.args) + _, _, _, err := invokeCommand(tc.args) if err == nil { t.Fatalf("expected an error, but got nil") } @@ -491,11 +417,11 @@ func TestFailServerConfigFlags(t *testing.T) { } func TestDefaultLoggingFormat(t *testing.T) { - c, _, err := invokeCommand([]string{}) + _, opts, _, err := invokeCommand([]string{}) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - got := c.cfg.LoggingFormat.String() + got := opts.Cfg.LoggingFormat.String() want := "standard" if got != want { t.Fatalf("unexpected default logging format flag: got %v, want %v", got, want) @@ -503,1377 +429,17 @@ func TestDefaultLoggingFormat(t *testing.T) { } func TestDefaultLogLevel(t *testing.T) { - c, _, err := invokeCommand([]string{}) + _, opts, _, err := invokeCommand([]string{}) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - got := c.cfg.LogLevel.String() + got := opts.Cfg.LogLevel.String() want := "info" if got != want { t.Fatalf("unexpected default log level flag: got %v, want %v", got, want) } } -func TestConvertToolsFile(t *testing.T) { - tcs := []struct { - desc string - in string - want string - isErr bool - errStr string - }{ - { - desc: "basic convert", - in: ` - sources: - my-pg-instance: - kind: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass - authServices: - my-google-auth: - kind: google - clientId: testing-id - tools: - example_tool: - kind: postgres-sql - source: my-pg-instance - description: some description - statement: SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description - toolsets: - example_toolset: - - example_tool - prompts: - code_review: - description: ask llm to analyze code quality - messages: - - content: "please review the following code for quality: {{.code}}" - arguments: - - name: code - description: the code to review - embeddingModels: - gemini-model: - kind: gemini - model: gemini-embedding-001 - apiKey: some-key - dimension: 768`, - want: `kind: sources -name: my-pg-instance -type: cloud-sql-postgres -project: my-project -region: my-region -instance: my-instance -database: my_db -user: my_user -password: my_pass ---- -kind: authServices -name: my-google-auth -type: google -clientId: testing-id ---- -kind: tools -name: example_tool -type: postgres-sql -source: my-pg-instance -description: some description -statement: SELECT * FROM SQL_STATEMENT; -parameters: -- name: country - type: string - description: some description ---- -kind: toolsets -name: example_toolset -tools: -- example_tool ---- -kind: prompts -name: code_review -description: ask llm to analyze code quality -messages: -- content: "please review the following code for quality: {{.code}}" -arguments: -- name: code - description: the code to review ---- -kind: embeddingModels -name: gemini-model -type: gemini -model: gemini-embedding-001 -apiKey: some-key -dimension: 768 -`, - }, - { - desc: "preserve resource order", - in: ` - tools: - example_tool: - kind: postgres-sql - source: my-pg-instance - description: some description - statement: SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description - sources: - my-pg-instance: - kind: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass - authServices: - my-google-auth: - kind: google - clientId: testing-id - toolsets: - example_toolset: - - example_tool - authSources: - my-google-auth2: - kind: google - clientId: testing-id`, - want: `kind: tools -name: example_tool -type: postgres-sql -source: my-pg-instance -description: some description -statement: SELECT * FROM SQL_STATEMENT; -parameters: -- name: country - type: string - description: some description ---- -kind: sources -name: my-pg-instance -type: cloud-sql-postgres -project: my-project -region: my-region -instance: my-instance -database: my_db -user: my_user -password: my_pass ---- -kind: authServices -name: my-google-auth -type: google -clientId: testing-id ---- -kind: toolsets -name: example_toolset -tools: -- example_tool ---- -kind: authServices -name: my-google-auth2 -type: google -clientId: testing-id -`, - }, - { - desc: "convert combination of v1 and v2", - in: ` - sources: - my-pg-instance: - kind: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass - authServices: - my-google-auth: - kind: google - clientId: testing-id - tools: - example_tool: - kind: postgres-sql - source: my-pg-instance - description: some description - statement: SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description - toolsets: - example_toolset: - - example_tool - prompts: - code_review: - description: ask llm to analyze code quality - messages: - - content: "please review the following code for quality: {{.code}}" - arguments: - - name: code - description: the code to review - embeddingModels: - gemini-model: - kind: gemini - model: gemini-embedding-001 - apiKey: some-key - dimension: 768 ---- - kind: sources - name: my-pg-instance2 - type: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance ---- - kind: authServices - name: my-google-auth2 - type: google - clientId: testing-id ---- - kind: tools - name: example_tool2 - type: postgres-sql - source: my-pg-instance - description: some description - statement: SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description ---- - kind: toolsets - name: example_toolset2 - tools: - - example_tool ---- - tools: - - example_tool - kind: toolsets - name: example_toolset3 ---- - kind: prompts - name: code_review2 - description: ask llm to analyze code quality - messages: - - content: "please review the following code for quality: {{.code}}" - arguments: - - name: code - description: the code to review ---- - kind: embeddingModels - name: gemini-model2 - type: gemini`, - want: `kind: sources -name: my-pg-instance -type: cloud-sql-postgres -project: my-project -region: my-region -instance: my-instance -database: my_db -user: my_user -password: my_pass ---- -kind: authServices -name: my-google-auth -type: google -clientId: testing-id ---- -kind: tools -name: example_tool -type: postgres-sql -source: my-pg-instance -description: some description -statement: SELECT * FROM SQL_STATEMENT; -parameters: -- name: country - type: string - description: some description ---- -kind: toolsets -name: example_toolset -tools: -- example_tool ---- -kind: prompts -name: code_review -description: ask llm to analyze code quality -messages: -- content: "please review the following code for quality: {{.code}}" -arguments: -- name: code - description: the code to review ---- -kind: embeddingModels -name: gemini-model -type: gemini -model: gemini-embedding-001 -apiKey: some-key -dimension: 768 ---- -kind: sources -name: my-pg-instance2 -type: cloud-sql-postgres -project: my-project -region: my-region -instance: my-instance ---- -kind: authServices -name: my-google-auth2 -type: google -clientId: testing-id ---- -kind: tools -name: example_tool2 -type: postgres-sql -source: my-pg-instance -description: some description -statement: SELECT * FROM SQL_STATEMENT; -parameters: -- name: country - type: string - description: some description ---- -kind: toolsets -name: example_toolset2 -tools: -- example_tool ---- -tools: -- example_tool -kind: toolsets -name: example_toolset3 ---- -kind: prompts -name: code_review2 -description: ask llm to analyze code quality -messages: -- content: "please review the following code for quality: {{.code}}" -arguments: -- name: code - description: the code to review ---- -kind: embeddingModels -name: gemini-model2 -type: gemini -`, - }, - { - desc: "no convertion needed", - in: `kind: sources -name: my-pg-instance -type: cloud-sql-postgres -project: my-project -region: my-region -instance: my-instance -database: my_db -user: my_user -password: my_pass ---- -kind: tools -name: example_tool -type: postgres-sql -source: my-pg-instance -description: some description -statement: SELECT * FROM SQL_STATEMENT; -parameters: -- name: country - type: string - description: some description ---- -kind: toolsets -name: example_toolset -tools: -- example_tool`, - want: `kind: sources -name: my-pg-instance -type: cloud-sql-postgres -project: my-project -region: my-region -instance: my-instance -database: my_db -user: my_user -password: my_pass ---- -kind: tools -name: example_tool -type: postgres-sql -source: my-pg-instance -description: some description -statement: SELECT * FROM SQL_STATEMENT; -parameters: -- name: country - type: string - description: some description ---- -kind: toolsets -name: example_toolset -tools: -- example_tool -`, - }, - { - desc: "invalid source", - in: `sources: invalid`, - want: "", - }, - { - desc: "invalid toolset", - in: `toolsets: invalid`, - want: "", - }, - } - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - output, err := convertToolsFile([]byte(tc.in)) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - if diff := cmp.Diff(string(output), tc.want); diff != "" { - t.Fatalf("incorrect toolsets parse: diff %v", diff) - } - }) - } -} - -func TestParseToolFile(t *testing.T) { - ctx, err := testutils.ContextWithNewLogger() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - tcs := []struct { - description string - in string - wantToolsFile ToolsFile - }{ - { - description: "basic example tools file v1", - in: ` - sources: - my-pg-instance: - kind: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass - tools: - example_tool: - kind: postgres-sql - source: my-pg-instance - description: some description - statement: | - SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description - toolsets: - example_toolset: - - example_tool - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-pg-instance": cloudsqlpgsrc.Config{ - Name: "my-pg-instance", - Type: cloudsqlpgsrc.SourceType, - Project: "my-project", - Region: "my-region", - Instance: "my-instance", - IPType: "public", - Database: "my_db", - User: "my_user", - Password: "my_pass", - }, - }, - Tools: server.ToolConfigs{ - "example_tool": postgressql.Config{ - Name: "example_tool", - Type: "postgres-sql", - Source: "my-pg-instance", - Description: "some description", - Statement: "SELECT * FROM SQL_STATEMENT;\n", - Parameters: []parameters.Parameter{ - parameters.NewStringParameter("country", "some description"), - }, - AuthRequired: []string{}, - }, - }, - Toolsets: server.ToolsetConfigs{ - "example_toolset": tools.ToolsetConfig{ - Name: "example_toolset", - ToolNames: []string{"example_tool"}, - }, - }, - AuthServices: nil, - Prompts: nil, - }, - }, - { - description: "basic example tools file v2", - in: ` - kind: sources - name: my-pg-instance - type: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass ---- - kind: authServices - name: my-google-auth - type: google - clientId: testing-id ---- - kind: embeddingModels - name: gemini-model - type: gemini - model: gemini-embedding-001 - apiKey: some-key - dimension: 768 ---- - kind: tools - name: example_tool - type: postgres-sql - source: my-pg-instance - description: some description - statement: | - SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description ---- - kind: toolsets - name: example_toolset - tools: - - example_tool ---- - kind: prompts - name: code_review - description: ask llm to analyze code quality - messages: - - content: "please review the following code for quality: {{.code}}" - arguments: - - name: code - description: the code to review - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-pg-instance": cloudsqlpgsrc.Config{ - Name: "my-pg-instance", - Type: cloudsqlpgsrc.SourceType, - Project: "my-project", - Region: "my-region", - Instance: "my-instance", - IPType: "public", - Database: "my_db", - User: "my_user", - Password: "my_pass", - }, - }, - AuthServices: server.AuthServiceConfigs{ - "my-google-auth": google.Config{ - Name: "my-google-auth", - Type: google.AuthServiceType, - ClientID: "testing-id", - }, - }, - EmbeddingModels: server.EmbeddingModelConfigs{ - "gemini-model": gemini.Config{ - Name: "gemini-model", - Type: gemini.EmbeddingModelType, - Model: "gemini-embedding-001", - ApiKey: "some-key", - Dimension: 768, - }, - }, - Tools: server.ToolConfigs{ - "example_tool": postgressql.Config{ - Name: "example_tool", - Type: "postgres-sql", - Source: "my-pg-instance", - Description: "some description", - Statement: "SELECT * FROM SQL_STATEMENT;\n", - Parameters: []parameters.Parameter{ - parameters.NewStringParameter("country", "some description"), - }, - AuthRequired: []string{}, - }, - }, - Toolsets: server.ToolsetConfigs{ - "example_toolset": tools.ToolsetConfig{ - Name: "example_toolset", - ToolNames: []string{"example_tool"}, - }, - }, - Prompts: server.PromptConfigs{ - "code_review": &custom.Config{ - Name: "code_review", - Description: "ask llm to analyze code quality", - Arguments: prompts.Arguments{ - {Parameter: parameters.NewStringParameter("code", "the code to review")}, - }, - Messages: []prompts.Message{ - {Role: "user", Content: "please review the following code for quality: {{.code}}"}, - }, - }, - }, - }, - }, - { - description: "only prompts", - in: ` - kind: prompts - name: my-prompt - description: A prompt template for data analysis. - arguments: - - name: country - description: The country to analyze. - messages: - - content: Analyze the data for {{.country}}. - `, - wantToolsFile: ToolsFile{ - Sources: nil, - AuthServices: nil, - Tools: nil, - Toolsets: nil, - Prompts: server.PromptConfigs{ - "my-prompt": &custom.Config{ - Name: "my-prompt", - Description: "A prompt template for data analysis.", - Arguments: prompts.Arguments{ - {Parameter: parameters.NewStringParameter("country", "The country to analyze.")}, - }, - Messages: []prompts.Message{ - {Role: "user", Content: "Analyze the data for {{.country}}."}, - }, - }, - }, - }, - }, - } - for _, tc := range tcs { - t.Run(tc.description, func(t *testing.T) { - toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in)) - if err != nil { - t.Fatalf("failed to parse input: %v", err) - } - if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" { - t.Fatalf("incorrect sources parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" { - t.Fatalf("incorrect authServices parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" { - t.Fatalf("incorrect tools parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" { - t.Fatalf("incorrect toolsets parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" { - t.Fatalf("incorrect prompts parse: diff %v", diff) - } - }) - } - -} - -func TestParseToolFileWithAuth(t *testing.T) { - ctx, err := testutils.ContextWithNewLogger() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - tcs := []struct { - description string - in string - wantToolsFile ToolsFile - }{ - { - description: "basic example", - in: ` - kind: sources - name: my-pg-instance - type: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass ---- - kind: authServices - name: my-google-service - type: google - clientId: my-client-id ---- - kind: authServices - name: other-google-service - type: google - clientId: other-client-id ---- - kind: tools - name: example_tool - type: postgres-sql - source: my-pg-instance - description: some description - statement: | - SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description - - name: id - type: integer - description: user id - authServices: - - name: my-google-service - field: user_id - - name: email - type: string - description: user email - authServices: - - name: my-google-service - field: email - - name: other-google-service - field: other_email ---- - kind: toolsets - name: example_toolset - tools: - - example_tool - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-pg-instance": cloudsqlpgsrc.Config{ - Name: "my-pg-instance", - Type: cloudsqlpgsrc.SourceType, - Project: "my-project", - Region: "my-region", - Instance: "my-instance", - IPType: "public", - Database: "my_db", - User: "my_user", - Password: "my_pass", - }, - }, - AuthServices: server.AuthServiceConfigs{ - "my-google-service": google.Config{ - Name: "my-google-service", - Type: google.AuthServiceType, - ClientID: "my-client-id", - }, - "other-google-service": google.Config{ - Name: "other-google-service", - Type: google.AuthServiceType, - ClientID: "other-client-id", - }, - }, - Tools: server.ToolConfigs{ - "example_tool": postgressql.Config{ - Name: "example_tool", - Type: "postgres-sql", - Source: "my-pg-instance", - Description: "some description", - Statement: "SELECT * FROM SQL_STATEMENT;\n", - AuthRequired: []string{}, - Parameters: []parameters.Parameter{ - parameters.NewStringParameter("country", "some description"), - parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}), - parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}), - }, - }, - }, - Toolsets: server.ToolsetConfigs{ - "example_toolset": tools.ToolsetConfig{ - Name: "example_toolset", - ToolNames: []string{"example_tool"}, - }, - }, - Prompts: nil, - }, - }, - { - description: "basic example with authSources", - in: ` - sources: - my-pg-instance: - kind: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass - authSources: - my-google-service: - kind: google - clientId: my-client-id - other-google-service: - kind: google - clientId: other-client-id - - tools: - example_tool: - kind: postgres-sql - source: my-pg-instance - description: some description - statement: | - SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description - - name: id - type: integer - description: user id - authSources: - - name: my-google-service - field: user_id - - name: email - type: string - description: user email - authSources: - - name: my-google-service - field: email - - name: other-google-service - field: other_email - - toolsets: - example_toolset: - - example_tool - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-pg-instance": cloudsqlpgsrc.Config{ - Name: "my-pg-instance", - Type: cloudsqlpgsrc.SourceType, - Project: "my-project", - Region: "my-region", - Instance: "my-instance", - IPType: "public", - Database: "my_db", - User: "my_user", - Password: "my_pass", - }, - }, - AuthServices: server.AuthServiceConfigs{ - "my-google-service": google.Config{ - Name: "my-google-service", - Type: google.AuthServiceType, - ClientID: "my-client-id", - }, - "other-google-service": google.Config{ - Name: "other-google-service", - Type: google.AuthServiceType, - ClientID: "other-client-id", - }, - }, - Tools: server.ToolConfigs{ - "example_tool": postgressql.Config{ - Name: "example_tool", - Type: "postgres-sql", - Source: "my-pg-instance", - Description: "some description", - Statement: "SELECT * FROM SQL_STATEMENT;\n", - AuthRequired: []string{}, - Parameters: []parameters.Parameter{ - parameters.NewStringParameter("country", "some description"), - parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}), - parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}), - }, - }, - }, - Toolsets: server.ToolsetConfigs{ - "example_toolset": tools.ToolsetConfig{ - Name: "example_toolset", - ToolNames: []string{"example_tool"}, - }, - }, - Prompts: nil, - }, - }, - { - description: "basic example with authRequired", - in: ` - kind: sources - name: my-pg-instance - type: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass ---- - kind: authServices - name: my-google-service - type: google - clientId: my-client-id ---- - kind: authServices - name: other-google-service - type: google - clientId: other-client-id ---- - kind: tools - name: example_tool - type: postgres-sql - source: my-pg-instance - description: some description - statement: | - SELECT * FROM SQL_STATEMENT; - authRequired: - - my-google-service - parameters: - - name: country - type: string - description: some description - - name: id - type: integer - description: user id - authServices: - - name: my-google-service - field: user_id - - name: email - type: string - description: user email - authServices: - - name: my-google-service - field: email - - name: other-google-service - field: other_email ---- - kind: toolsets - name: example_toolset - tools: - - example_tool - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-pg-instance": cloudsqlpgsrc.Config{ - Name: "my-pg-instance", - Type: cloudsqlpgsrc.SourceType, - Project: "my-project", - Region: "my-region", - Instance: "my-instance", - IPType: "public", - Database: "my_db", - User: "my_user", - Password: "my_pass", - }, - }, - AuthServices: server.AuthServiceConfigs{ - "my-google-service": google.Config{ - Name: "my-google-service", - Type: google.AuthServiceType, - ClientID: "my-client-id", - }, - "other-google-service": google.Config{ - Name: "other-google-service", - Type: google.AuthServiceType, - ClientID: "other-client-id", - }, - }, - Tools: server.ToolConfigs{ - "example_tool": postgressql.Config{ - Name: "example_tool", - Type: "postgres-sql", - Source: "my-pg-instance", - Description: "some description", - Statement: "SELECT * FROM SQL_STATEMENT;\n", - AuthRequired: []string{"my-google-service"}, - Parameters: []parameters.Parameter{ - parameters.NewStringParameter("country", "some description"), - parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}), - parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}), - }, - }, - }, - Toolsets: server.ToolsetConfigs{ - "example_toolset": tools.ToolsetConfig{ - Name: "example_toolset", - ToolNames: []string{"example_tool"}, - }, - }, - Prompts: nil, - }, - }, - } - for _, tc := range tcs { - t.Run(tc.description, func(t *testing.T) { - toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in)) - if err != nil { - t.Fatalf("failed to parse input: %v", err) - } - if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" { - t.Fatalf("incorrect sources parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" { - t.Fatalf("incorrect authServices parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" { - t.Fatalf("incorrect tools parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" { - t.Fatalf("incorrect toolsets parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" { - t.Fatalf("incorrect prompts parse: diff %v", diff) - } - }) - } - -} - -func TestEnvVarReplacement(t *testing.T) { - ctx, err := testutils.ContextWithNewLogger() - t.Setenv("TestHeader", "ACTUAL_HEADER") - t.Setenv("API_KEY", "ACTUAL_API_KEY") - t.Setenv("clientId", "ACTUAL_CLIENT_ID") - t.Setenv("clientId2", "ACTUAL_CLIENT_ID_2") - t.Setenv("toolset_name", "ACTUAL_TOOLSET_NAME") - t.Setenv("cat_string", "cat") - t.Setenv("food_string", "food") - t.Setenv("TestHeader", "ACTUAL_HEADER") - t.Setenv("prompt_name", "ACTUAL_PROMPT_NAME") - t.Setenv("prompt_content", "ACTUAL_CONTENT") - - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - tcs := []struct { - description string - in string - wantToolsFile ToolsFile - }{ - { - description: "file with env var example", - in: ` - sources: - my-http-instance: - kind: http - baseUrl: http://test_server/ - timeout: 10s - headers: - Authorization: ${TestHeader} - queryParams: - api-key: ${API_KEY} - authServices: - my-google-service: - kind: google - clientId: ${clientId} - other-google-service: - kind: google - clientId: ${clientId2} - - tools: - example_tool: - kind: http - source: my-instance - method: GET - path: "search?name=alice&pet=${cat_string}" - description: some description - authRequired: - - my-google-auth-service - - other-auth-service - queryParams: - - name: country - type: string - description: some description - authServices: - - name: my-google-auth-service - field: user_id - - name: other-auth-service - field: user_id - requestBody: | - { - "age": {{.age}}, - "city": "{{.city}}", - "food": "${food_string}", - "other": "$OTHER" - } - bodyParams: - - name: age - type: integer - description: age num - - name: city - type: string - description: city string - headers: - Authorization: API_KEY - Content-Type: application/json - headerParams: - - name: Language - type: string - description: language string - - toolsets: - ${toolset_name}: - - example_tool - - - prompts: - ${prompt_name}: - description: A test prompt for {{.name}}. - messages: - - role: user - content: ${prompt_content} - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-http-instance": httpsrc.Config{ - Name: "my-http-instance", - Type: httpsrc.SourceType, - BaseURL: "http://test_server/", - Timeout: "10s", - DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"}, - QueryParams: map[string]string{"api-key": "ACTUAL_API_KEY"}, - }, - }, - AuthServices: server.AuthServiceConfigs{ - "my-google-service": google.Config{ - Name: "my-google-service", - Type: google.AuthServiceType, - ClientID: "ACTUAL_CLIENT_ID", - }, - "other-google-service": google.Config{ - Name: "other-google-service", - Type: google.AuthServiceType, - ClientID: "ACTUAL_CLIENT_ID_2", - }, - }, - Tools: server.ToolConfigs{ - "example_tool": http.Config{ - Name: "example_tool", - Type: "http", - Source: "my-instance", - Method: "GET", - Path: "search?name=alice&pet=cat", - Description: "some description", - AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, - QueryParams: []parameters.Parameter{ - parameters.NewStringParameterWithAuth("country", "some description", - []parameters.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, - {Name: "other-auth-service", Field: "user_id"}}), - }, - RequestBody: `{ - "age": {{.age}}, - "city": "{{.city}}", - "food": "food", - "other": "$OTHER" -} -`, - BodyParams: []parameters.Parameter{parameters.NewIntParameter("age", "age num"), parameters.NewStringParameter("city", "city string")}, - Headers: map[string]string{"Authorization": "API_KEY", "Content-Type": "application/json"}, - HeaderParams: []parameters.Parameter{parameters.NewStringParameter("Language", "language string")}, - }, - }, - Toolsets: server.ToolsetConfigs{ - "ACTUAL_TOOLSET_NAME": tools.ToolsetConfig{ - Name: "ACTUAL_TOOLSET_NAME", - ToolNames: []string{"example_tool"}, - }, - }, - Prompts: server.PromptConfigs{ - "ACTUAL_PROMPT_NAME": &custom.Config{ - Name: "ACTUAL_PROMPT_NAME", - Description: "A test prompt for {{.name}}.", - Messages: []prompts.Message{ - { - Role: "user", - Content: "ACTUAL_CONTENT", - }, - }, - Arguments: nil, - }, - }, - }, - }, - { - description: "file with env var example toolsfile v2", - in: ` - kind: sources - name: my-http-instance - type: http - baseUrl: http://test_server/ - timeout: 10s - headers: - Authorization: ${TestHeader} - queryParams: - api-key: ${API_KEY} ---- - kind: authServices - name: my-google-service - type: google - clientId: ${clientId} ---- - kind: authServices - name: other-google-service - type: google - clientId: ${clientId2} ---- - kind: tools - name: example_tool - type: http - source: my-instance - method: GET - path: "search?name=alice&pet=${cat_string}" - description: some description - authRequired: - - my-google-auth-service - - other-auth-service - queryParams: - - name: country - type: string - description: some description - authServices: - - name: my-google-auth-service - field: user_id - - name: other-auth-service - field: user_id - requestBody: | - { - "age": {{.age}}, - "city": "{{.city}}", - "food": "${food_string}", - "other": "$OTHER" - } - bodyParams: - - name: age - type: integer - description: age num - - name: city - type: string - description: city string - headers: - Authorization: API_KEY - Content-Type: application/json - headerParams: - - name: Language - type: string - description: language string ---- - kind: toolsets - name: ${toolset_name} - tools: - - example_tool ---- - kind: prompts - name: ${prompt_name} - description: A test prompt for {{.name}}. - messages: - - role: user - content: ${prompt_content} - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-http-instance": httpsrc.Config{ - Name: "my-http-instance", - Type: httpsrc.SourceType, - BaseURL: "http://test_server/", - Timeout: "10s", - DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"}, - QueryParams: map[string]string{"api-key": "ACTUAL_API_KEY"}, - }, - }, - AuthServices: server.AuthServiceConfigs{ - "my-google-service": google.Config{ - Name: "my-google-service", - Type: google.AuthServiceType, - ClientID: "ACTUAL_CLIENT_ID", - }, - "other-google-service": google.Config{ - Name: "other-google-service", - Type: google.AuthServiceType, - ClientID: "ACTUAL_CLIENT_ID_2", - }, - }, - Tools: server.ToolConfigs{ - "example_tool": http.Config{ - Name: "example_tool", - Type: "http", - Source: "my-instance", - Method: "GET", - Path: "search?name=alice&pet=cat", - Description: "some description", - AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, - QueryParams: []parameters.Parameter{ - parameters.NewStringParameterWithAuth("country", "some description", - []parameters.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, - {Name: "other-auth-service", Field: "user_id"}}), - }, - RequestBody: `{ - "age": {{.age}}, - "city": "{{.city}}", - "food": "food", - "other": "$OTHER" -} -`, - BodyParams: []parameters.Parameter{parameters.NewIntParameter("age", "age num"), parameters.NewStringParameter("city", "city string")}, - Headers: map[string]string{"Authorization": "API_KEY", "Content-Type": "application/json"}, - HeaderParams: []parameters.Parameter{parameters.NewStringParameter("Language", "language string")}, - }, - }, - Toolsets: server.ToolsetConfigs{ - "ACTUAL_TOOLSET_NAME": tools.ToolsetConfig{ - Name: "ACTUAL_TOOLSET_NAME", - ToolNames: []string{"example_tool"}, - }, - }, - Prompts: server.PromptConfigs{ - "ACTUAL_PROMPT_NAME": &custom.Config{ - Name: "ACTUAL_PROMPT_NAME", - Description: "A test prompt for {{.name}}.", - Messages: []prompts.Message{ - { - Role: "user", - Content: "ACTUAL_CONTENT", - }, - }, - Arguments: nil, - }, - }, - }, - }, - } - for _, tc := range tcs { - t.Run(tc.description, func(t *testing.T) { - toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in)) - if err != nil { - t.Fatalf("failed to parse input: %v", err) - } - if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" { - t.Fatalf("incorrect sources parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" { - t.Fatalf("incorrect authServices parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" { - t.Fatalf("incorrect tools parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" { - t.Fatalf("incorrect toolsets parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" { - t.Fatalf("incorrect prompts parse: diff %v", diff) - } - }) - } -} - // normalizeFilepaths is a helper function to allow same filepath formats for Mac and Windows. // this prevents needing multiple "want" cases for TestResolveWatcherInputs func normalizeFilepaths(m map[string]bool) map[string]bool { @@ -2052,485 +618,6 @@ func TestSingleEdit(t *testing.T) { } } -func TestPrebuiltTools(t *testing.T) { - // Get prebuilt configs - alloydb_omni_config, _ := prebuiltconfigs.Get("alloydb-omni") - alloydb_admin_config, _ := prebuiltconfigs.Get("alloydb-postgres-admin") - alloydb_config, _ := prebuiltconfigs.Get("alloydb-postgres") - bigquery_config, _ := prebuiltconfigs.Get("bigquery") - clickhouse_config, _ := prebuiltconfigs.Get("clickhouse") - cloudsqlpg_config, _ := prebuiltconfigs.Get("cloud-sql-postgres") - cloudsqlpg_admin_config, _ := prebuiltconfigs.Get("cloud-sql-postgres-admin") - cloudsqlmysql_config, _ := prebuiltconfigs.Get("cloud-sql-mysql") - cloudsqlmysql_admin_config, _ := prebuiltconfigs.Get("cloud-sql-mysql-admin") - cloudsqlmssql_config, _ := prebuiltconfigs.Get("cloud-sql-mssql") - cloudsqlmssql_admin_config, _ := prebuiltconfigs.Get("cloud-sql-mssql-admin") - dataplex_config, _ := prebuiltconfigs.Get("dataplex") - firestoreconfig, _ := prebuiltconfigs.Get("firestore") - mysql_config, _ := prebuiltconfigs.Get("mysql") - mssql_config, _ := prebuiltconfigs.Get("mssql") - looker_config, _ := prebuiltconfigs.Get("looker") - lookerca_config, _ := prebuiltconfigs.Get("looker-conversational-analytics") - postgresconfig, _ := prebuiltconfigs.Get("postgres") - spanner_config, _ := prebuiltconfigs.Get("spanner") - spannerpg_config, _ := prebuiltconfigs.Get("spanner-postgres") - mindsdb_config, _ := prebuiltconfigs.Get("mindsdb") - sqlite_config, _ := prebuiltconfigs.Get("sqlite") - neo4jconfig, _ := prebuiltconfigs.Get("neo4j") - alloydbobsvconfig, _ := prebuiltconfigs.Get("alloydb-postgres-observability") - cloudsqlpgobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-postgres-observability") - cloudsqlmysqlobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-mysql-observability") - cloudsqlmssqlobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-mssql-observability") - serverless_spark_config, _ := prebuiltconfigs.Get("serverless-spark") - cloudhealthcare_config, _ := prebuiltconfigs.Get("cloud-healthcare") - snowflake_config, _ := prebuiltconfigs.Get("snowflake") - - // Set environment variables - t.Setenv("API_KEY", "your_api_key") - - t.Setenv("BIGQUERY_PROJECT", "your_gcp_project_id") - t.Setenv("DATAPLEX_PROJECT", "your_gcp_project_id") - t.Setenv("FIRESTORE_PROJECT", "your_gcp_project_id") - t.Setenv("FIRESTORE_DATABASE", "your_firestore_db_name") - - t.Setenv("SPANNER_PROJECT", "your_gcp_project_id") - t.Setenv("SPANNER_INSTANCE", "your_spanner_instance") - t.Setenv("SPANNER_DATABASE", "your_spanner_db") - - t.Setenv("ALLOYDB_POSTGRES_PROJECT", "your_gcp_project_id") - t.Setenv("ALLOYDB_POSTGRES_REGION", "your_gcp_region") - t.Setenv("ALLOYDB_POSTGRES_CLUSTER", "your_alloydb_cluster") - t.Setenv("ALLOYDB_POSTGRES_INSTANCE", "your_alloydb_instance") - t.Setenv("ALLOYDB_POSTGRES_DATABASE", "your_alloydb_db") - t.Setenv("ALLOYDB_POSTGRES_USER", "your_alloydb_user") - t.Setenv("ALLOYDB_POSTGRES_PASSWORD", "your_alloydb_password") - - t.Setenv("ALLOYDB_OMNI_HOST", "localhost") - t.Setenv("ALLOYDB_OMNI_PORT", "5432") - t.Setenv("ALLOYDB_OMNI_DATABASE", "your_alloydb_db") - t.Setenv("ALLOYDB_OMNI_USER", "your_alloydb_user") - t.Setenv("ALLOYDB_OMNI_PASSWORD", "your_alloydb_password") - - t.Setenv("CLICKHOUSE_PROTOCOL", "your_clickhouse_protocol") - t.Setenv("CLICKHOUSE_DATABASE", "your_clickhouse_database") - t.Setenv("CLICKHOUSE_PASSWORD", "your_clickhouse_password") - t.Setenv("CLICKHOUSE_USER", "your_clickhouse_user") - t.Setenv("CLICKHOUSE_HOST", "your_clickhosue_host") - t.Setenv("CLICKHOUSE_PORT", "8123") - - t.Setenv("CLOUD_SQL_POSTGRES_PROJECT", "your_pg_project") - t.Setenv("CLOUD_SQL_POSTGRES_INSTANCE", "your_pg_instance") - t.Setenv("CLOUD_SQL_POSTGRES_DATABASE", "your_pg_db") - t.Setenv("CLOUD_SQL_POSTGRES_REGION", "your_pg_region") - t.Setenv("CLOUD_SQL_POSTGRES_USER", "your_pg_user") - t.Setenv("CLOUD_SQL_POSTGRES_PASS", "your_pg_pass") - - t.Setenv("CLOUD_SQL_MYSQL_PROJECT", "your_gcp_project_id") - t.Setenv("CLOUD_SQL_MYSQL_REGION", "your_gcp_region") - t.Setenv("CLOUD_SQL_MYSQL_INSTANCE", "your_instance") - t.Setenv("CLOUD_SQL_MYSQL_DATABASE", "your_cloudsql_mysql_db") - t.Setenv("CLOUD_SQL_MYSQL_USER", "your_cloudsql_mysql_user") - t.Setenv("CLOUD_SQL_MYSQL_PASSWORD", "your_cloudsql_mysql_password") - - t.Setenv("CLOUD_SQL_MSSQL_PROJECT", "your_gcp_project_id") - t.Setenv("CLOUD_SQL_MSSQL_REGION", "your_gcp_region") - t.Setenv("CLOUD_SQL_MSSQL_INSTANCE", "your_cloudsql_mssql_instance") - t.Setenv("CLOUD_SQL_MSSQL_DATABASE", "your_cloudsql_mssql_db") - t.Setenv("CLOUD_SQL_MSSQL_IP_ADDRESS", "127.0.0.1") - t.Setenv("CLOUD_SQL_MSSQL_USER", "your_cloudsql_mssql_user") - t.Setenv("CLOUD_SQL_MSSQL_PASSWORD", "your_cloudsql_mssql_password") - t.Setenv("CLOUD_SQL_POSTGRES_PASSWORD", "your_cloudsql_pg_password") - - t.Setenv("SERVERLESS_SPARK_PROJECT", "your_gcp_project_id") - t.Setenv("SERVERLESS_SPARK_LOCATION", "your_gcp_location") - - t.Setenv("POSTGRES_HOST", "localhost") - t.Setenv("POSTGRES_PORT", "5432") - t.Setenv("POSTGRES_DATABASE", "your_postgres_db") - t.Setenv("POSTGRES_USER", "your_postgres_user") - t.Setenv("POSTGRES_PASSWORD", "your_postgres_password") - - t.Setenv("MYSQL_HOST", "localhost") - t.Setenv("MYSQL_PORT", "3306") - t.Setenv("MYSQL_DATABASE", "your_mysql_db") - t.Setenv("MYSQL_USER", "your_mysql_user") - t.Setenv("MYSQL_PASSWORD", "your_mysql_password") - - t.Setenv("MSSQL_HOST", "localhost") - t.Setenv("MSSQL_PORT", "1433") - t.Setenv("MSSQL_DATABASE", "your_mssql_db") - t.Setenv("MSSQL_USER", "your_mssql_user") - t.Setenv("MSSQL_PASSWORD", "your_mssql_password") - - t.Setenv("MINDSDB_HOST", "localhost") - t.Setenv("MINDSDB_PORT", "47334") - t.Setenv("MINDSDB_DATABASE", "your_mindsdb_db") - t.Setenv("MINDSDB_USER", "your_mindsdb_user") - t.Setenv("MINDSDB_PASS", "your_mindsdb_password") - - t.Setenv("LOOKER_BASE_URL", "https://your_company.looker.com") - t.Setenv("LOOKER_CLIENT_ID", "your_looker_client_id") - t.Setenv("LOOKER_CLIENT_SECRET", "your_looker_client_secret") - t.Setenv("LOOKER_VERIFY_SSL", "true") - - t.Setenv("LOOKER_PROJECT", "your_project_id") - t.Setenv("LOOKER_LOCATION", "us") - - t.Setenv("SQLITE_DATABASE", "test.db") - - t.Setenv("NEO4J_URI", "bolt://localhost:7687") - t.Setenv("NEO4J_DATABASE", "neo4j") - t.Setenv("NEO4J_USERNAME", "your_neo4j_user") - t.Setenv("NEO4J_PASSWORD", "your_neo4j_password") - - t.Setenv("CLOUD_HEALTHCARE_PROJECT", "your_gcp_project_id") - t.Setenv("CLOUD_HEALTHCARE_REGION", "your_gcp_region") - t.Setenv("CLOUD_HEALTHCARE_DATASET", "your_healthcare_dataset") - - t.Setenv("SNOWFLAKE_ACCOUNT", "your_account") - t.Setenv("SNOWFLAKE_USER", "your_username") - t.Setenv("SNOWFLAKE_PASSWORD", "your_pass") - t.Setenv("SNOWFLAKE_DATABASE", "your_db") - t.Setenv("SNOWFLAKE_SCHEMA", "your_schema") - t.Setenv("SNOWFLAKE_WAREHOUSE", "your_wh") - t.Setenv("SNOWFLAKE_ROLE", "your_role") - - ctx, err := testutils.ContextWithNewLogger() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - tcs := []struct { - name string - in []byte - wantToolset server.ToolsetConfigs - }{ - { - name: "alloydb omni prebuilt tools", - in: alloydb_omni_config, - wantToolset: server.ToolsetConfigs{ - "alloydb_omni_database_tools": tools.ToolsetConfig{ - Name: "alloydb_omni_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_columnar_configurations", "list_columnar_recommended_columns", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, - }, - }, - }, - { - name: "alloydb postgres admin prebuilt tools", - in: alloydb_admin_config, - wantToolset: server.ToolsetConfigs{ - "alloydb_postgres_admin_tools": tools.ToolsetConfig{ - Name: "alloydb_postgres_admin_tools", - ToolNames: []string{"create_cluster", "wait_for_operation", "create_instance", "list_clusters", "list_instances", "list_users", "create_user", "get_cluster", "get_instance", "get_user"}, - }, - }, - }, - { - name: "cloudsql pg admin prebuilt tools", - in: cloudsqlpg_admin_config, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_postgres_admin_tools": tools.ToolsetConfig{ - Name: "cloud_sql_postgres_admin_tools", - ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup", "restore_backup"}, - }, - }, - }, - { - name: "cloudsql mysql admin prebuilt tools", - in: cloudsqlmysql_admin_config, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_mysql_admin_tools": tools.ToolsetConfig{ - Name: "cloud_sql_mysql_admin_tools", - ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"}, - }, - }, - }, - { - name: "cloudsql mssql admin prebuilt tools", - in: cloudsqlmssql_admin_config, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_mssql_admin_tools": tools.ToolsetConfig{ - Name: "cloud_sql_mssql_admin_tools", - ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"}, - }, - }, - }, - { - name: "alloydb prebuilt tools", - in: alloydb_config, - wantToolset: server.ToolsetConfigs{ - "alloydb_postgres_database_tools": tools.ToolsetConfig{ - Name: "alloydb_postgres_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, - }, - }, - }, - { - name: "bigquery prebuilt tools", - in: bigquery_config, - wantToolset: server.ToolsetConfigs{ - "bigquery_database_tools": tools.ToolsetConfig{ - Name: "bigquery_database_tools", - ToolNames: []string{"analyze_contribution", "ask_data_insights", "execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids", "search_catalog"}, - }, - }, - }, - { - name: "clickhouse prebuilt tools", - in: clickhouse_config, - wantToolset: server.ToolsetConfigs{ - "clickhouse_database_tools": tools.ToolsetConfig{ - Name: "clickhouse_database_tools", - ToolNames: []string{"execute_sql", "list_databases", "list_tables"}, - }, - }, - }, - { - name: "cloudsqlpg prebuilt tools", - in: cloudsqlpg_config, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_postgres_database_tools": tools.ToolsetConfig{ - Name: "cloud_sql_postgres_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, - }, - }, - }, - { - name: "cloudsqlmysql prebuilt tools", - in: cloudsqlmysql_config, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_mysql_database_tools": tools.ToolsetConfig{ - Name: "cloud_sql_mysql_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "get_query_plan", "list_active_queries", "list_tables_missing_unique_indexes", "list_table_fragmentation"}, - }, - }, - }, - { - name: "cloudsqlmssql prebuilt tools", - in: cloudsqlmssql_config, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_mssql_database_tools": tools.ToolsetConfig{ - Name: "cloud_sql_mssql_database_tools", - ToolNames: []string{"execute_sql", "list_tables"}, - }, - }, - }, - { - name: "dataplex prebuilt tools", - in: dataplex_config, - wantToolset: server.ToolsetConfigs{ - "dataplex_tools": tools.ToolsetConfig{ - Name: "dataplex_tools", - ToolNames: []string{"search_entries", "lookup_entry", "search_aspect_types"}, - }, - }, - }, - { - name: "serverless spark prebuilt tools", - in: serverless_spark_config, - wantToolset: server.ToolsetConfigs{ - "serverless_spark_tools": tools.ToolsetConfig{ - Name: "serverless_spark_tools", - ToolNames: []string{"list_batches", "get_batch", "cancel_batch", "create_pyspark_batch", "create_spark_batch"}, - }, - }, - }, - { - name: "firestore prebuilt tools", - in: firestoreconfig, - wantToolset: server.ToolsetConfigs{ - "firestore_database_tools": tools.ToolsetConfig{ - Name: "firestore_database_tools", - ToolNames: []string{"get_documents", "add_documents", "update_document", "list_collections", "delete_documents", "query_collection", "get_rules", "validate_rules"}, - }, - }, - }, - { - name: "mysql prebuilt tools", - in: mysql_config, - wantToolset: server.ToolsetConfigs{ - "mysql_database_tools": tools.ToolsetConfig{ - Name: "mysql_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "get_query_plan", "list_active_queries", "list_tables_missing_unique_indexes", "list_table_fragmentation"}, - }, - }, - }, - { - name: "mssql prebuilt tools", - in: mssql_config, - wantToolset: server.ToolsetConfigs{ - "mssql_database_tools": tools.ToolsetConfig{ - Name: "mssql_database_tools", - ToolNames: []string{"execute_sql", "list_tables"}, - }, - }, - }, - { - name: "looker prebuilt tools", - in: looker_config, - wantToolset: server.ToolsetConfigs{ - "looker_tools": tools.ToolsetConfig{ - Name: "looker_tools", - ToolNames: []string{"get_models", "get_explores", "get_dimensions", "get_measures", "get_filters", "get_parameters", "query", "query_sql", "query_url", "get_looks", "run_look", "make_look", "get_dashboards", "run_dashboard", "make_dashboard", "add_dashboard_element", "add_dashboard_filter", "generate_embed_url", "health_pulse", "health_analyze", "health_vacuum", "dev_mode", "get_projects", "get_project_files", "get_project_file", "create_project_file", "update_project_file", "delete_project_file", "validate_project", "get_connections", "get_connection_schemas", "get_connection_databases", "get_connection_tables", "get_connection_table_columns"}, - }, - }, - }, - { - name: "looker-conversational-analytics prebuilt tools", - in: lookerca_config, - wantToolset: server.ToolsetConfigs{ - "looker_conversational_analytics_tools": tools.ToolsetConfig{ - Name: "looker_conversational_analytics_tools", - ToolNames: []string{"ask_data_insights", "get_models", "get_explores"}, - }, - }, - }, - { - name: "postgres prebuilt tools", - in: postgresconfig, - wantToolset: server.ToolsetConfigs{ - "postgres_database_tools": tools.ToolsetConfig{ - Name: "postgres_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, - }, - }, - }, - { - name: "spanner prebuilt tools", - in: spanner_config, - wantToolset: server.ToolsetConfigs{ - "spanner-database-tools": tools.ToolsetConfig{ - Name: "spanner-database-tools", - ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables", "list_graphs"}, - }, - }, - }, - { - name: "spanner pg prebuilt tools", - in: spannerpg_config, - wantToolset: server.ToolsetConfigs{ - "spanner_postgres_database_tools": tools.ToolsetConfig{ - Name: "spanner_postgres_database_tools", - ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables"}, - }, - }, - }, - { - name: "mindsdb prebuilt tools", - in: mindsdb_config, - wantToolset: server.ToolsetConfigs{ - "mindsdb-tools": tools.ToolsetConfig{ - Name: "mindsdb-tools", - ToolNames: []string{"mindsdb-execute-sql", "mindsdb-sql"}, - }, - }, - }, - { - name: "sqlite prebuilt tools", - in: sqlite_config, - wantToolset: server.ToolsetConfigs{ - "sqlite_database_tools": tools.ToolsetConfig{ - Name: "sqlite_database_tools", - ToolNames: []string{"execute_sql", "list_tables"}, - }, - }, - }, - { - name: "neo4j prebuilt tools", - in: neo4jconfig, - wantToolset: server.ToolsetConfigs{ - "neo4j_database_tools": tools.ToolsetConfig{ - Name: "neo4j_database_tools", - ToolNames: []string{"execute_cypher", "get_schema"}, - }, - }, - }, - { - name: "alloydb postgres observability prebuilt tools", - in: alloydbobsvconfig, - wantToolset: server.ToolsetConfigs{ - "alloydb_postgres_cloud_monitoring_tools": tools.ToolsetConfig{ - Name: "alloydb_postgres_cloud_monitoring_tools", - ToolNames: []string{"get_system_metrics", "get_query_metrics"}, - }, - }, - }, - { - name: "cloudsql postgres observability prebuilt tools", - in: cloudsqlpgobsvconfig, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_postgres_cloud_monitoring_tools": tools.ToolsetConfig{ - Name: "cloud_sql_postgres_cloud_monitoring_tools", - ToolNames: []string{"get_system_metrics", "get_query_metrics"}, - }, - }, - }, - { - name: "cloudsql mysql observability prebuilt tools", - in: cloudsqlmysqlobsvconfig, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_mysql_cloud_monitoring_tools": tools.ToolsetConfig{ - Name: "cloud_sql_mysql_cloud_monitoring_tools", - ToolNames: []string{"get_system_metrics", "get_query_metrics"}, - }, - }, - }, - { - name: "cloudsql mssql observability prebuilt tools", - in: cloudsqlmssqlobsvconfig, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_mssql_cloud_monitoring_tools": tools.ToolsetConfig{ - Name: "cloud_sql_mssql_cloud_monitoring_tools", - ToolNames: []string{"get_system_metrics"}, - }, - }, - }, - { - name: "cloud healthcare prebuilt tools", - in: cloudhealthcare_config, - wantToolset: server.ToolsetConfigs{ - "cloud_healthcare_dataset_tools": tools.ToolsetConfig{ - Name: "cloud_healthcare_dataset_tools", - ToolNames: []string{"get_dataset", "list_dicom_stores", "list_fhir_stores"}, - }, - "cloud_healthcare_fhir_tools": tools.ToolsetConfig{ - Name: "cloud_healthcare_fhir_tools", - ToolNames: []string{"get_fhir_store", "get_fhir_store_metrics", "get_fhir_resource", "fhir_patient_search", "fhir_patient_everything", "fhir_fetch_page"}, - }, - "cloud_healthcare_dicom_tools": tools.ToolsetConfig{ - Name: "cloud_healthcare_dicom_tools", - ToolNames: []string{"get_dicom_store", "get_dicom_store_metrics", "search_dicom_studies", "search_dicom_series", "search_dicom_instances", "retrieve_rendered_dicom_instance"}, - }, - }, - }, - { - name: "Snowflake prebuilt tool", - in: snowflake_config, - wantToolset: server.ToolsetConfigs{ - "snowflake_tools": tools.ToolsetConfig{ - Name: "snowflake_tools", - ToolNames: []string{"execute_sql", "list_tables"}, - }, - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - toolsFile, err := parseToolsFile(ctx, tc.in) - if err != nil { - t.Fatalf("failed to parse input: %v", err) - } - if diff := cmp.Diff(tc.wantToolset, toolsFile.Toolsets); diff != "" { - t.Fatalf("incorrect tools parse: diff %v", diff) - } - // Prebuilt configs do not have prompts, so assert empty maps. - if len(toolsFile.Prompts) != 0 { - t.Fatalf("expected empty prompts map for prebuilt config, got: %v", toolsFile.Prompts) - } - }) - } -} - func TestMutuallyExclusiveFlags(t *testing.T) { testCases := []struct { desc string @@ -2551,7 +638,9 @@ func TestMutuallyExclusiveFlags(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - cmd := NewCommand() + buf := new(bytes.Buffer) + opts := internal.NewToolboxOptions(internal.WithIOStreams(buf, buf)) + cmd := NewCommand(opts) cmd.SetArgs(tc.args) err := cmd.Execute() if err == nil { @@ -2566,7 +655,9 @@ func TestMutuallyExclusiveFlags(t *testing.T) { func TestFileLoadingErrors(t *testing.T) { t.Run("non-existent tools-file", func(t *testing.T) { - cmd := NewCommand() + buf := new(bytes.Buffer) + opts := internal.NewToolboxOptions(internal.WithIOStreams(buf, buf)) + cmd := NewCommand(opts) // Use a file that is guaranteed not to exist nonExistentFile := filepath.Join(t.TempDir(), "non-existent-tools.yaml") cmd.SetArgs([]string{"--tools-file", nonExistentFile}) @@ -2581,7 +672,9 @@ func TestFileLoadingErrors(t *testing.T) { }) t.Run("non-existent tools-folder", func(t *testing.T) { - cmd := NewCommand() + buf := new(bytes.Buffer) + opts := internal.NewToolboxOptions(internal.WithIOStreams(buf, buf)) + cmd := NewCommand(opts) nonExistentFolder := filepath.Join(t.TempDir(), "non-existent-folder") cmd.SetArgs([]string{"--tools-folder", nonExistentFolder}) @@ -2595,94 +688,6 @@ func TestFileLoadingErrors(t *testing.T) { }) } -func TestMergeToolsFiles(t *testing.T) { - file1 := ToolsFile{ - Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, - Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}}, - Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}}, - EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, - } - file2 := ToolsFile{ - AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}}, - Tools: server.ToolConfigs{"tool2": http.Config{Name: "tool2"}}, - Toolsets: server.ToolsetConfigs{"set2": tools.ToolsetConfig{Name: "set2"}}, - } - fileWithConflicts := ToolsFile{ - Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, - Tools: server.ToolConfigs{"tool2": http.Config{Name: "tool2"}}, - } - - testCases := []struct { - name string - files []ToolsFile - want ToolsFile - wantErr bool - }{ - { - name: "merge two distinct files", - files: []ToolsFile{file1, file2}, - want: ToolsFile{ - Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, - AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}}, - Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}}, - Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}}, - Prompts: server.PromptConfigs{}, - EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, - }, - wantErr: false, - }, - { - name: "merge with conflicts", - files: []ToolsFile{file1, file2, fileWithConflicts}, - wantErr: true, - }, - { - name: "merge single file", - files: []ToolsFile{file1}, - want: ToolsFile{ - Sources: file1.Sources, - AuthServices: make(server.AuthServiceConfigs), - EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, - Tools: file1.Tools, - Toolsets: file1.Toolsets, - Prompts: server.PromptConfigs{}, - }, - }, - { - name: "merge empty list", - files: []ToolsFile{}, - want: ToolsFile{ - Sources: make(server.SourceConfigs), - AuthServices: make(server.AuthServiceConfigs), - EmbeddingModels: make(server.EmbeddingModelConfigs), - Tools: make(server.ToolConfigs), - Toolsets: make(server.ToolsetConfigs), - Prompts: server.PromptConfigs{}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got, err := mergeToolsFiles(tc.files...) - if (err != nil) != tc.wantErr { - t.Fatalf("mergeToolsFiles() error = %v, wantErr %v", err, tc.wantErr) - } - if !tc.wantErr { - if diff := cmp.Diff(tc.want, got); diff != "" { - t.Errorf("mergeToolsFiles() mismatch (-want +got):\n%s", diff) - } - } else { - if err == nil { - t.Fatal("expected an error for conflicting files but got none") - } - if !strings.Contains(err.Error(), "resource conflicts detected") { - t.Errorf("expected conflict error, but got: %v", err) - } - } - }) - } -} func TestPrebuiltAndCustomTools(t *testing.T) { t.Setenv("SQLITE_DATABASE", "test.db") // Setup custom tools file @@ -2848,7 +853,7 @@ authSources: ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - cmd, output, err := invokeCommandWithContext(ctx, tc.args) + _, opts, output, err := invokeCommandWithContext(ctx, tc.args) if tc.wantErr { if err == nil { @@ -2865,7 +870,7 @@ authSources: t.Errorf("server did not start successfully (no ready message found). Output:\n%s", output) } if tc.cfgCheck != nil { - if err := tc.cfgCheck(cmd.cfg); err != nil { + if err := tc.cfgCheck(opts.Cfg); err != nil { t.Errorf("config check failed: %v", err) } } @@ -2899,7 +904,7 @@ func TestDefaultToolsFileBehavior(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - _, output, err := invokeCommandWithContext(ctx, tc.args) + _, _, output, err := invokeCommandWithContext(ctx, tc.args) if tc.expectRun { if err != nil && err != context.DeadlineExceeded && err != context.Canceled { @@ -2921,114 +926,29 @@ func TestDefaultToolsFileBehavior(t *testing.T) { } } -func TestParameterReferenceValidation(t *testing.T) { - ctx, err := testutils.ContextWithNewLogger() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } +func TestSubcommandWiring(t *testing.T) { + buf := new(bytes.Buffer) + opts := internal.NewToolboxOptions(internal.WithIOStreams(buf, buf)) + baseCmd := NewCommand(opts) - // Base template - baseYaml := ` -sources: - dummy-source: - kind: http - baseUrl: http://example.com -tools: - test-tool: - kind: postgres-sql - source: dummy-source - description: test tool - statement: SELECT 1; - parameters: -%s` - - tcs := []struct { - desc string - params string - wantErr bool - errSubstr string + tests := []struct { + args []string + expectedName string }{ - { - desc: "valid backward reference", - params: ` - - name: source_param - type: string - description: source - - name: copy_param - type: string - description: copy - valueFromParam: source_param`, - wantErr: false, - }, - { - desc: "valid forward reference (out of order)", - params: ` - - name: copy_param - type: string - description: copy - valueFromParam: source_param - - name: source_param - type: string - description: source`, - wantErr: false, - }, - { - desc: "invalid missing reference", - params: ` - - name: copy_param - type: string - description: copy - valueFromParam: non_existent_param`, - wantErr: true, - errSubstr: "references '\"non_existent_param\"' in the 'valueFromParam' field", - }, - { - desc: "invalid self reference", - params: ` - - name: myself - type: string - description: self - valueFromParam: myself`, - wantErr: true, - errSubstr: "parameter \"myself\" cannot copy value from itself", - }, - { - desc: "multiple valid references", - params: ` - - name: a - type: string - description: a - - name: b - type: string - description: b - valueFromParam: a - - name: c - type: string - description: c - valueFromParam: a`, - wantErr: false, - }, + {[]string{"invoke"}, "invoke"}, + {[]string{"skills-generate"}, "skills-generate"}, } - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - // Indent parameters to match YAML structure - yamlContent := fmt.Sprintf(baseYaml, tc.params) + for _, tc := range tests { + // Find returns the Command struct and the remaining args + cmd, _, err := baseCmd.Find(tc.args) - _, err := parseToolsFile(ctx, []byte(yamlContent)) + if err != nil { + t.Fatalf("Failed to find command %v: %v", tc.args, err) + } - if tc.wantErr { - if err == nil { - t.Fatal("expected error, got nil") - } - if !strings.Contains(err.Error(), tc.errSubstr) { - t.Errorf("error %q does not contain expected substring %q", err.Error(), tc.errSubstr) - } - } else { - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - } - }) + if cmd.Name() != tc.expectedName { + t.Errorf("Expected command name %q, got %q", tc.expectedName, cmd.Name()) + } } } diff --git a/tests/server.go b/tests/server.go index eac693f3a2..ac8ecb13d4 100644 --- a/tests/server.go +++ b/tests/server.go @@ -21,6 +21,7 @@ import ( "os" yaml "github.com/goccy/go-yaml" + "github.com/spf13/cobra" "github.com/googleapis/genai-toolbox/cmd" ) @@ -50,7 +51,7 @@ func tmpFileWithCleanup(content []byte) (string, func(), error) { type CmdExec struct { Out io.ReadCloser - cmd *cmd.Command + cmd *cobra.Command cancel context.CancelFunc closers []io.Closer done chan bool // closed once the cmd is completed @@ -77,7 +78,7 @@ func StartCmd(ctx context.Context, toolsFile map[string]any, args ...string) (*C return nil, nil, fmt.Errorf("unable to open stdout pipe: %w", err) } - c := cmd.NewCommand(cmd.WithStreams(pw, pw)) + c := cmd.GenerateCommand(pw, pw) c.SetArgs(args) t := &CmdExec{