From 10d7397c0137e05c28c6f91945cd0f6e09b2e22f Mon Sep 17 00:00:00 2001 From: duwenxin99 Date: Wed, 4 Feb 2026 11:22:32 -0500 Subject: [PATCH] all refactor done --- .../elasticsearchesql/elasticsearchesql.go | 14 +- .../lookerdeleteprojectfile.go | 14 +- .../looker/lookerdevmode/lookerdevmode.go | 16 +- .../lookergenerateembedurl.go | 13 +- .../lookergetconnectiondatabases.go | 13 +- .../lookergetconnections.go | 13 +- .../lookergetconnectionschemas.go | 12 +- .../lookergetconnectiontablecolumns.go | 19 +- .../lookergetconnectiontables.go | 15 +- .../lookergetdashboards.go | 11 +- .../lookergetdimensions.go | 17 +- .../lookergetexplores/lookergetexplores.go | 13 +- .../lookerhealthpulse/lookerhealthpulse.go | 13 +- .../lookerhealthvacuum/lookerhealthvacuum.go | 23 +- .../lookermakedashboard.go | 19 +- .../looker/lookermakelook/lookermakelook.go | 23 +- .../tools/looker/lookerquery/lookerquery.go | 15 +- .../looker/lookerquerysql/lookerquerysql.go | 13 +- .../looker/lookerqueryurl/lookerqueryurl.go | 18 +- .../lookerrundashboard/lookerrundashboard.go | 11 +- .../looker/lookerrunlook/lookerrunlook.go | 15 +- .../lookerupdateprojectfile.go | 16 +- .../mindsdbexecutesql/mindsdbexecutesql.go | 19 +- .../tools/mindsdb/mindsdbsql/mindsdbsql.go | 19 +- .../mongodbaggregate/mongodbaggregate.go | 14 +- .../mongodbdeletemany/mongodbdeletemany.go | 14 +- .../mongodbdeleteone/mongodbdeleteone.go | 16 +- .../tools/mongodb/mongodbfind/mongodbfind.go | 17 +- .../mongodb/mongodbfindone/mongodbfindone.go | 18 +- .../mongodbinsertmany/mongodbinsertmany.go | 17 +- .../mongodbinsertone/mongodbinsertone.go | 17 +- .../mongodbupdatemany/mongodbupdatemany.go | 16 +- .../mongodbupdateone/mongodbupdateone.go | 16 +- .../mssql/mssqlexecutesql/mssqlexecutesql.go | 19 +- .../mssql/mssqllisttables/mssqllisttables.go | 504 +++++++++--------- internal/tools/mssql/mssqlsql/mssqlsql.go | 16 +- .../mysql/mysqlexecutesql/mysqlexecutesql.go | 19 +- .../mysqlgetqueryplan/mysqlgetqueryplan.go | 25 +- .../mysqllistactivequeries.go | 129 ++--- .../mysqllisttablefragmentation.go | 59 +- .../mysql/mysqllisttables/mysqllisttables.go | 14 +- .../mysqllisttablesmissinguniqueindexes.go | 57 +- internal/tools/mysql/mysqlsql/mysqlsql.go | 16 +- .../tools/neo4j/neo4jcypher/neo4jcypher.go | 12 +- .../neo4jexecutecypher/neo4jexecutecypher.go | 18 +- .../tools/neo4j/neo4jschema/neo4jschema.go | 32 +- .../oceanbaseexecutesql.go | 14 +- .../oceanbase/oceanbasesql/oceanbasesql.go | 16 +- .../oracleexecutesql/oracleexecutesql.go | 15 +- internal/tools/oracle/oraclesql/oraclesql.go | 16 +- .../postgresdatabaseoverview.go | 25 +- .../postgresexecutesql/postgresexecutesql.go | 19 +- .../postgresgetcolumncardinality.go | 19 +- .../postgreslistactivequeries.go | 62 +-- .../postgreslistavailableextensions.go | 29 +- .../postgreslistdatabasestats.go | 132 ++--- .../postgreslistindexes.go | 103 ++-- .../postgreslistinstalledextensions.go | 52 +- .../postgreslistlocks/postgreslistlocks.go | 19 +- .../postgreslistpgsettings.go | 27 +- .../postgreslistpublicationtables.go | 71 +-- .../postgreslistquerystats.go | 19 +- .../postgreslistroles/postgreslistroles.go | 95 ++-- .../postgreslistschemas.go | 53 +- .../postgreslistsequences.go | 49 +- .../postgresliststoredprocedure.go | 15 +- .../postgreslisttables/postgreslisttables.go | 166 +++--- .../postgreslisttablespaces.go | 65 +-- .../postgreslisttablestats.go | 35 +- .../postgreslisttriggers.go | 103 ++-- .../postgreslistviews/postgreslistviews.go | 53 +- .../postgreslongrunningtransactions.go | 19 +- .../postgresreplicationstats.go | 25 +- .../tools/postgres/postgressql/postgressql.go | 19 +- 74 files changed, 1504 insertions(+), 1270 deletions(-) diff --git a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go index 12387da8b4..3cf828d30e 100644 --- a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go +++ b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go @@ -17,9 +17,11 @@ package elasticsearchesql import ( "context" "fmt" + "net/http" "time" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/goccy/go-yaml" @@ -89,10 +91,10 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } var cancel context.CancelFunc @@ -119,11 +121,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para for _, param := range t.Parameters { if param.GetType() == "array" { - return nil, fmt.Errorf("array parameters are not supported yet") + return nil, util.NewAgentError("array parameters are not supported yet", nil) } sqlParams = append(sqlParams, map[string]any{param.GetName(): paramMap[param.GetName()]}) } - return source.RunSQL(ctx, t.Format, query, sqlParams) + resp, err := source.RunSQL(ctx, t.Format, query, sqlParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go b/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go index 7644adee6a..741b6bb220 100644 --- a/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go +++ b/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go @@ -16,12 +16,14 @@ package lookerdeleteprojectfile import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/looker-open-source/sdk-codegen/go/rtl" @@ -111,30 +113,30 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } mapParams := params.AsMap() projectId, ok := mapParams["project_id"].(string) if !ok { - return nil, fmt.Errorf("'project_id' must be a string, got %T", mapParams["project_id"]) + return nil, util.NewAgentError(fmt.Sprintf("'project_id' must be a string, got %T", mapParams["project_id"]), nil) } filePath, ok := mapParams["file_path"].(string) if !ok { - return nil, fmt.Errorf("'file_path' must be a string, got %T", mapParams["file_path"]) + return nil, util.NewAgentError(fmt.Sprintf("'file_path' must be a string, got %T", mapParams["file_path"]), nil) } err = lookercommon.DeleteProjectFile(sdk, projectId, filePath, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making delete_project_file request: %s", err) + return nil, util.ProcessGeneralError(err) } data := make(map[string]any) diff --git a/internal/tools/looker/lookerdevmode/lookerdevmode.go b/internal/tools/looker/lookerdevmode/lookerdevmode.go index ea16d4a7ad..274062052f 100644 --- a/internal/tools/looker/lookerdevmode/lookerdevmode.go +++ b/internal/tools/looker/lookerdevmode/lookerdevmode.go @@ -16,6 +16,7 @@ package lookerdevmode import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -60,7 +61,6 @@ type Config struct { Annotations *tools.ToolAnnotations `yaml:"annotations,omitempty"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -81,7 +81,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) - // finish tool setup return Tool{ Config: cfg, Parameters: params, @@ -94,7 +93,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -108,25 +106,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } mapParams := params.AsMap() devMode, ok := mapParams["devMode"].(bool) if !ok { - return nil, fmt.Errorf("'devMode' must be a boolean, got %T", mapParams["devMode"]) + return nil, util.NewAgentError(fmt.Sprintf("'devMode' must be a boolean, got %T", mapParams["devMode"]), nil) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } var devModeString string if devMode { @@ -139,7 +137,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.UpdateSession(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error setting/resetting dev mode: %w", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "result = ", resp) diff --git a/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go b/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go index 9ffc6f2f8e..908aae9e18 100644 --- a/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go +++ b/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -17,6 +17,7 @@ package lookergenerateembedurl import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -114,15 +115,15 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } paramsMap := params.AsMap() embedType := paramsMap["type"].(string) @@ -138,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } forceLogoutLogin := true @@ -151,7 +152,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para logger.ErrorContext(ctx, "Making request %v", req) resp, err := sdk.CreateEmbedUrlAsMe(req, nil) if err != nil { - return nil, fmt.Errorf("error making create_embed_url_as_me request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.ErrorContext(ctx, "Got response %v", resp) diff --git a/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go b/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go index 81c156f78b..c62fcef1f1 100644 --- a/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go +++ b/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go @@ -16,11 +16,13 @@ package lookergetconnectiondatabases import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/looker-open-source/sdk-codegen/go/rtl" @@ -107,27 +109,26 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() conn, ok := mapParams["conn"].(string) if !ok { - return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"]) + return nil, util.NewAgentError(fmt.Sprintf("'conn' must be a string, got %T", mapParams["conn"]), nil) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } resp, err := sdk.ConnectionDatabases(conn, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_connection_databases request: %s", err) + return nil, util.ProcessGeneralError(err) } - //logger.DebugContext(ctx, "Got response of %v\n", resp) return resp, nil } diff --git a/internal/tools/looker/lookergetconnections/lookergetconnections.go b/internal/tools/looker/lookergetconnections/lookergetconnections.go index e223df0c04..585b30f8ca 100644 --- a/internal/tools/looker/lookergetconnections/lookergetconnections.go +++ b/internal/tools/looker/lookergetconnections/lookergetconnections.go @@ -16,6 +16,7 @@ package lookergetconnections import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -107,24 +108,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } resp, err := sdk.AllConnections("name, dialect(name), database, schema", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_connections request: %s", err) + return nil, util.ProcessGeneralError(err) } var data []any @@ -140,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } conn, err := sdk.ConnectionFeatures(*v.Name, "multiple_databases", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_connection_features request: %s", err) + return nil, util.ProcessGeneralError(err) } vMap["supports_multiple_databases"] = *conn.MultipleDatabases logger.DebugContext(ctx, "Converted to %v\n", vMap) diff --git a/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go b/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go index e7385d1c64..3528b08a2e 100644 --- a/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go +++ b/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go @@ -16,11 +16,13 @@ package lookergetconnectionschemas import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/looker-open-source/sdk-codegen/go/rtl" @@ -108,22 +110,22 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } mapParams := params.AsMap() conn, ok := mapParams["conn"].(string) if !ok { - return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"]) + return nil, util.NewAgentError(fmt.Sprintf("'conn' must be a string, got %T", mapParams["conn"]), nil) } db, _ := mapParams["db"].(string) sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } req := v4.RequestConnectionSchemas{ ConnectionName: conn, @@ -133,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.ConnectionSchemas(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_connection_schemas request: %s", err) + return nil, util.ProcessGeneralError(err) } return resp, nil } diff --git a/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go b/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go index 263034e73a..61e3974c1a 100644 --- a/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go +++ b/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go @@ -8,7 +8,7 @@ // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY, either express or implied. +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package lookergetconnectiontablecolumns @@ -16,6 +16,7 @@ package lookergetconnectiontablecolumns import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -111,34 +112,34 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } mapParams := params.AsMap() conn, ok := mapParams["conn"].(string) if !ok { - return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"]) + return nil, util.NewAgentError(fmt.Sprintf("'conn' must be a string, got %T", mapParams["conn"]), nil) } db, _ := mapParams["db"].(string) schema, ok := mapParams["schema"].(string) if !ok { - return nil, fmt.Errorf("'schema' must be a string, got %T", mapParams["schema"]) + return nil, util.NewAgentError(fmt.Sprintf("'schema' must be a string, got %T", mapParams["schema"]), nil) } tables, ok := mapParams["tables"].(string) if !ok { - return nil, fmt.Errorf("'tables' must be a string, got %T", mapParams["tables"]) + return nil, util.NewAgentError(fmt.Sprintf("'tables' must be a string, got %T", mapParams["tables"]), nil) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } req := v4.RequestConnectionColumns{ ConnectionName: conn, @@ -150,7 +151,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.ConnectionColumns(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_connection_table_columns request: %s", err) + return nil, util.ProcessGeneralError(err) } var data []any for _, t := range resp { diff --git a/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go b/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go index 6d993f96a4..0a26997efe 100644 --- a/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go +++ b/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go @@ -16,6 +16,7 @@ package lookergetconnectiontables import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -110,30 +111,30 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } mapParams := params.AsMap() conn, ok := mapParams["conn"].(string) if !ok { - return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"]) + return nil, util.NewAgentError(fmt.Sprintf("'conn' must be a string, got %T", mapParams["conn"]), nil) } db, _ := mapParams["db"].(string) schema, ok := mapParams["schema"].(string) if !ok { - return nil, fmt.Errorf("'schema' must be a string, got %T", mapParams["schema"]) + return nil, util.NewAgentError(fmt.Sprintf("'schema' must be a string, got %T", mapParams["schema"]), nil) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } req := v4.RequestConnectionTables{ ConnectionName: conn, @@ -144,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.ConnectionTables(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_connection_tables request: %s", err) + return nil, util.ProcessGeneralError(err) } var data []any for _, s := range resp { diff --git a/internal/tools/looker/lookergetdashboards/lookergetdashboards.go b/internal/tools/looker/lookergetdashboards/lookergetdashboards.go index 64ac53783e..a4df73fbf7 100644 --- a/internal/tools/looker/lookergetdashboards/lookergetdashboards.go +++ b/internal/tools/looker/lookergetdashboards/lookergetdashboards.go @@ -16,6 +16,7 @@ package lookergetdashboards import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -116,15 +117,15 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } paramsMap := params.AsMap() title := paramsMap["title"].(string) @@ -142,7 +143,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } req := v4.RequestSearchDashboards{ Title: title_ptr, @@ -153,7 +154,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para logger.ErrorContext(ctx, "Making request %v", req) resp, err := sdk.SearchDashboards(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_dashboards request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.ErrorContext(ctx, "Got response %v", resp) var data []any diff --git a/internal/tools/looker/lookergetdimensions/lookergetdimensions.go b/internal/tools/looker/lookergetdimensions/lookergetdimensions.go index 3494207128..3373fe7db4 100644 --- a/internal/tools/looker/lookergetdimensions/lookergetdimensions.go +++ b/internal/tools/looker/lookergetdimensions/lookergetdimensions.go @@ -16,6 +16,7 @@ package lookergetdimensions import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -109,24 +110,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } model, explore, err := lookercommon.ProcessFieldArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error processing model or explore: %w", err) + return nil, util.NewAgentError("error processing model or explore", err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } fields := lookercommon.DimensionsFields req := v4.RequestLookmlModelExplore{ @@ -136,16 +137,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_dimensions request: %w", err) + return nil, util.ProcessGeneralError(err) } if err := lookercommon.CheckLookerExploreFields(&resp); err != nil { - return nil, fmt.Errorf("error processing get_dimensions response: %w", err) + return nil, util.ProcessGeneralError(err) } data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Dimensions, source.LookerShowHiddenFields()) if err != nil { - return nil, fmt.Errorf("error extracting get_dimensions response: %w", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "data = ", data) diff --git a/internal/tools/looker/lookergetexplores/lookergetexplores.go b/internal/tools/looker/lookergetexplores/lookergetexplores.go index ea5c83e45f..2a58ffdcde 100644 --- a/internal/tools/looker/lookergetexplores/lookergetexplores.go +++ b/internal/tools/looker/lookergetexplores/lookergetexplores.go @@ -16,6 +16,7 @@ package lookergetexplores import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -109,29 +110,29 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } mapParams := params.AsMap() model, ok := mapParams["model"].(string) if !ok { - return nil, fmt.Errorf("'model' must be a string, got %T", mapParams["model"]) + return nil, util.NewAgentError(fmt.Sprintf("'model' must be a string, got %T", mapParams["model"]), nil) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } resp, err := sdk.LookmlModel(model, "explores(name,description,label,group_label,hidden)", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making get_explores request: %s", err) + return nil, util.ProcessGeneralError(err) } var data []any diff --git a/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go b/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go index fd5c3ead21..6d0177ecf5 100644 --- a/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go +++ b/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "strings" yaml "github.com/goccy/go-yaml" @@ -116,20 +117,20 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } pulseTool := &pulseTool{ @@ -140,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para paramsMap := params.AsMap() action, ok := paramsMap["action"].(string) if !ok { - return nil, fmt.Errorf("action parameter not found") + return nil, util.NewAgentError("action parameter not found", nil) } pulseParams := PulseParams{ @@ -149,7 +150,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para result, err := pulseTool.RunPulse(ctx, source, pulseParams) if err != nil { - return nil, fmt.Errorf("error running pulse: %w", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "result = ", result) diff --git a/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go b/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go index e0a0580b1d..c3b5658d57 100644 --- a/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go +++ b/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "regexp" "strings" @@ -125,15 +126,15 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } paramsMap := params.AsMap() @@ -154,21 +155,29 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para action, ok := paramsMap["action"].(string) if !ok { - return nil, fmt.Errorf("action parameter not found") + return nil, util.NewAgentError("action parameter not found", nil) } + var res []map[string]interface{} + switch action { case "models": project, _ := paramsMap["project"].(string) model, _ := paramsMap["model"].(string) - return vacuumTool.models(ctx, project, model) + res, err = vacuumTool.models(ctx, project, model) case "explores": model, _ := paramsMap["model"].(string) explore, _ := paramsMap["explore"].(string) - return vacuumTool.explores(ctx, model, explore) + res, err = vacuumTool.explores(ctx, model, explore) default: - return nil, fmt.Errorf("unknown action: %s", action) + return nil, util.NewAgentError(fmt.Sprintf("unknown action: %s", action), nil) } + + if err != nil { + return nil, util.ProcessGeneralError(err) + } + + return res, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/looker/lookermakedashboard/lookermakedashboard.go b/internal/tools/looker/lookermakedashboard/lookermakedashboard.go index e4da154180..8934e46dc9 100644 --- a/internal/tools/looker/lookermakedashboard/lookermakedashboard.go +++ b/internal/tools/looker/lookermakedashboard/lookermakedashboard.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "slices" yaml "github.com/goccy/go-yaml" @@ -116,21 +117,21 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "params = ", params) sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } paramsMap := params.AsMap() @@ -141,19 +142,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para mrespFields := "id,personal_folder_id" mresp, err := sdk.Me(mrespFields, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making me request: %s", err) + return nil, util.ProcessGeneralError(err) } if folder == "" { if mresp.PersonalFolderId == nil || *mresp.PersonalFolderId == "" { - return nil, fmt.Errorf("user does not have a personal folder. A folder must be specified") + return nil, util.NewAgentError("user does not have a personal folder. A folder must be specified", nil) } folder = *mresp.PersonalFolderId } dashs, err := sdk.FolderDashboards(folder, "title", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error getting existing dashboards in folder: %s", err) + return nil, util.ProcessGeneralError(err) } dashTitles := []string{} @@ -162,7 +163,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if slices.Contains(dashTitles, title) { lt, _ := json.Marshal(dashTitles) - return nil, fmt.Errorf("title %s already used in folder. Currently used titles are %v. Make the call again with a unique title", title, string(lt)) + return nil, util.NewAgentError(fmt.Sprintf("title %s already used in folder. Currently used titles are %v. Make the call again with a unique title", title, string(lt)), nil) } wd := v4.WriteDashboard{ @@ -172,7 +173,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.CreateDashboard(wd, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making create dashboard request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = %v", resp) diff --git a/internal/tools/looker/lookermakelook/lookermakelook.go b/internal/tools/looker/lookermakelook/lookermakelook.go index 46d6d61841..3dcaf91716 100644 --- a/internal/tools/looker/lookermakelook/lookermakelook.go +++ b/internal/tools/looker/lookermakelook/lookermakelook.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "slices" yaml "github.com/goccy/go-yaml" @@ -123,25 +124,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "params = ", params) wq, err := lookercommon.ProcessQueryArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error building query request: %w", err) + return nil, util.NewAgentError("error building query request", err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } paramsMap := params.AsMap() title := paramsMap["title"].(string) @@ -152,19 +153,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para mrespFields := "id,personal_folder_id" mresp, err := sdk.Me(mrespFields, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making me request: %s", err) + return nil, util.ProcessGeneralError(err) } if folder == "" { if mresp.PersonalFolderId == nil || *mresp.PersonalFolderId == "" { - return nil, fmt.Errorf("user does not have a personal folder. A folder must be specified") + return nil, util.NewAgentError("user does not have a personal folder. A folder must be specified", nil) } folder = *mresp.PersonalFolderId } looks, err := sdk.FolderLooks(folder, "title", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error getting existing looks in folder: %s", err) + return nil, util.ProcessGeneralError(err) } lookTitles := []string{} @@ -173,7 +174,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if slices.Contains(lookTitles, title) { lt, _ := json.Marshal(lookTitles) - return nil, fmt.Errorf("title %s already used in folder. Currently used titles are %v. Make the call again with a unique title", title, string(lt)) + return nil, util.NewAgentError(fmt.Sprintf("title %s already used in folder. Currently used titles are %v. Make the call again with a unique title", title, string(lt)), nil) } wq.VisConfig = &visConfig @@ -181,7 +182,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para qrespFields := "id" qresp, err := sdk.CreateQuery(*wq, qrespFields, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making create query request: %s", err) + return nil, util.ProcessGeneralError(err) } wlwq := v4.WriteLookWithQuery{ @@ -193,7 +194,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := sdk.CreateLook(wlwq, "", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making create look request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = %v", resp) diff --git a/internal/tools/looker/lookerquery/lookerquery.go b/internal/tools/looker/lookerquery/lookerquery.go index 1fb6e43f1e..2d099f7519 100644 --- a/internal/tools/looker/lookerquery/lookerquery.go +++ b/internal/tools/looker/lookerquery/lookerquery.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -109,27 +110,27 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } wq, err := lookercommon.ProcessQueryArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error building WriteQuery request: %w", err) + return nil, util.NewAgentError("error building WriteQuery request", err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } resp, err := lookercommon.RunInlineQuery(ctx, sdk, wq, "json", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making query request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = ", resp) @@ -137,7 +138,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var data []any e := json.Unmarshal([]byte(resp), &data) if e != nil { - return nil, fmt.Errorf("error unmarshaling query response: %s", e) + return nil, util.NewClientServerError("error unmarshaling query response", http.StatusInternalServerError, e) } logger.DebugContext(ctx, "data = ", data) diff --git a/internal/tools/looker/lookerquerysql/lookerquerysql.go b/internal/tools/looker/lookerquerysql/lookerquerysql.go index c796b4e90c..577b67678e 100644 --- a/internal/tools/looker/lookerquerysql/lookerquerysql.go +++ b/internal/tools/looker/lookerquerysql/lookerquerysql.go @@ -16,6 +16,7 @@ package lookerquerysql import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -108,27 +109,27 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } wq, err := lookercommon.ProcessQueryArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error building query request: %w", err) + return nil, util.NewAgentError("error building query request", err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } resp, err := lookercommon.RunInlineQuery(ctx, sdk, wq, "sql", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making query request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = ", resp) diff --git a/internal/tools/looker/lookerqueryurl/lookerqueryurl.go b/internal/tools/looker/lookerqueryurl/lookerqueryurl.go index 566745307c..3b9050b796 100644 --- a/internal/tools/looker/lookerqueryurl/lookerqueryurl.go +++ b/internal/tools/looker/lookerqueryurl/lookerqueryurl.go @@ -16,6 +16,7 @@ package lookerqueryurl import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -115,34 +116,37 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "params = ", params) wq, err := lookercommon.ProcessQueryArgs(ctx, params) if err != nil { - return nil, fmt.Errorf("error building query request: %w", err) + return nil, util.NewAgentError("error building query request", err) } paramsMap := params.AsMap() - visConfig := paramsMap["vis_config"].(map[string]any) + visConfig, ok := paramsMap["vis_config"].(map[string]any) + if !ok { + visConfig = make(map[string]any) + } wq.VisConfig = &visConfig sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } respFields := "id,slug,share_url,expanded_share_url" resp, err := sdk.CreateQuery(*wq, respFields, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making query request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = ", resp) diff --git a/internal/tools/looker/lookerrundashboard/lookerrundashboard.go b/internal/tools/looker/lookerrundashboard/lookerrundashboard.go index a40ccddc82..20c9433452 100644 --- a/internal/tools/looker/lookerrundashboard/lookerrundashboard.go +++ b/internal/tools/looker/lookerrundashboard/lookerrundashboard.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "sync" yaml "github.com/goccy/go-yaml" @@ -114,15 +115,15 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "params = ", params) paramsMap := params.AsMap() @@ -131,11 +132,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } dashboard, err := sdk.Dashboard(dashboard_id, "", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error getting dashboard: %w", err) + return nil, util.ProcessGeneralError(err) } data := make(map[string]any) diff --git a/internal/tools/looker/lookerrunlook/lookerrunlook.go b/internal/tools/looker/lookerrunlook/lookerrunlook.go index 2ab69a36a4..5dec4ae308 100644 --- a/internal/tools/looker/lookerrunlook/lookerrunlook.go +++ b/internal/tools/looker/lookerrunlook/lookerrunlook.go @@ -17,6 +17,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -115,15 +116,15 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + return nil, util.NewClientServerError("unable to get logger from ctx", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "params = ", params) paramsMap := params.AsMap() @@ -134,12 +135,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } look, err := sdk.Look(look_id, "", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error getting look definition: %s", err) + return nil, util.ProcessGeneralError(err) } wq := v4.WriteQuery{ @@ -155,14 +156,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para resp, err := lookercommon.RunInlineQuery(ctx, sdk, &wq, "json", source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making run_look request: %s", err) + return nil, util.ProcessGeneralError(err) } logger.DebugContext(ctx, "resp = ", resp) var data []any e := json.Unmarshal([]byte(resp), &data) if e != nil { - return nil, fmt.Errorf("error Unmarshaling run_look response: %s", e) + return nil, util.NewClientServerError("error Unmarshaling run_look response", http.StatusInternalServerError, e) } logger.DebugContext(ctx, "data = ", data) diff --git a/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go b/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go index 284872b2cf..35ea7328b1 100644 --- a/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go +++ b/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go @@ -16,12 +16,14 @@ package lookerupdateprojectfile import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/looker-open-source/sdk-codegen/go/rtl" @@ -111,29 +113,29 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } sdk, err := source.GetLookerSDK(string(accessToken)) if err != nil { - return nil, fmt.Errorf("error getting sdk: %w", err) + return nil, util.NewClientServerError("error getting sdk", http.StatusInternalServerError, err) } mapParams := params.AsMap() projectId, ok := mapParams["project_id"].(string) if !ok { - return nil, fmt.Errorf("'project_id' must be a string, got %T", mapParams["project_id"]) + return nil, util.NewAgentError(fmt.Sprintf("'project_id' must be a string, got %T", mapParams["project_id"]), nil) } filePath, ok := mapParams["file_path"].(string) if !ok { - return nil, fmt.Errorf("'file_path' must be a string, got %T", mapParams["file_path"]) + return nil, util.NewAgentError(fmt.Sprintf("'file_path' must be a string, got %T", mapParams["file_path"]), nil) } fileContent, ok := mapParams["file_content"].(string) if !ok { - return nil, fmt.Errorf("'file_content' must be a string, got %T", mapParams["file_content"]) + return nil, util.NewAgentError(fmt.Sprintf("'file_content' must be a string, got %T", mapParams["file_content"]), nil) } req := lookercommon.FileContent{ @@ -143,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para err = lookercommon.UpdateProjectFile(sdk, projectId, req, source.LookerApiSettings()) if err != nil { - return nil, fmt.Errorf("error making update_project_file request: %s", err) + return nil, util.ProcessGeneralError(err) } data := make(map[string]any) diff --git a/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go b/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go index bcbd5fba9a..9e0e3c644a 100644 --- a/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go +++ b/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -55,7 +57,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -73,7 +74,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) InputSchema: inputSchema, } - // finish tool setup t := Tool{ Config: cfg, Parameters: params, @@ -83,7 +83,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -97,19 +96,23 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - sql, ok := paramsMap["sql"].(string) + sqlStr, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql"]), nil) } - return source.RunSQL(ctx, sql, nil) + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go b/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go index 0eb1c8eea7..782846ca6b 100644 --- a/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go +++ b/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -58,7 +60,6 @@ type Config struct { TemplateParameters parameters.Parameters `yaml:"templateParameters"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -79,7 +80,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) InputSchema: paramMcpManifest, } - // finish tool setup t := Tool{ Config: cfg, AllParams: allParameters, @@ -89,7 +89,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -99,25 +98,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + resp, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go b/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go index 24dce16680..519e79b696 100644 --- a/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go +++ b/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go @@ -16,10 +16,12 @@ package mongodbaggregate import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" @@ -102,18 +104,22 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() pipelineString, err := parameters.PopulateTemplateWithJSON("MongoDBAggregatePipeline", t.PipelinePayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating pipeline: %s", err) + return nil, util.NewAgentError("error populating pipeline", err) } - return source.Aggregate(ctx, pipelineString, t.Canonical, t.ReadOnly, t.Database, t.Collection) + resp, err := source.Aggregate(ctx, pipelineString, t.Canonical, t.ReadOnly, t.Database, t.Collection) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go b/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go index 0d6f8c2be8..8f67aec41f 100644 --- a/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go +++ b/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go @@ -16,10 +16,12 @@ package mongodbdeletemany import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" @@ -106,18 +108,22 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBDeleteManyFilter", t.FilterPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating filter: %s", err) + return nil, util.NewAgentError("error populating filter", err) } - return source.DeleteMany(ctx, filterString, t.Database, t.Collection) + resp, err := source.DeleteMany(ctx, filterString, t.Database, t.Collection) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go b/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go index 416a67ffe3..55f855c7e0 100644 --- a/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go +++ b/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go @@ -9,17 +9,19 @@ // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and +// See the License for the language governing permissions and // limitations under the License. package mongodbdeleteone import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" @@ -106,19 +108,23 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBDeleteOneFilter", t.FilterPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating filter: %s", err) + return nil, util.NewAgentError("error populating filter", err) } - return source.DeleteOne(ctx, filterString, t.Database, t.Collection) + resp, err := source.DeleteOne(ctx, filterString, t.Database, t.Collection) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbfind/mongodbfind.go b/internal/tools/mongodb/mongodbfind/mongodbfind.go index 9389088f0f..9da5a8b30e 100644 --- a/internal/tools/mongodb/mongodbfind/mongodbfind.go +++ b/internal/tools/mongodb/mongodbfind/mongodbfind.go @@ -16,6 +16,7 @@ package mongodbfind import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" @@ -121,7 +122,7 @@ type Tool struct { func getOptions(ctx context.Context, sortParameters parameters.Parameters, projectPayload string, limit int64, paramsMap map[string]any) (*options.FindOptionsBuilder, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { - panic(err) + return nil, err } opts := options.Find() @@ -157,22 +158,26 @@ func getOptions(ctx context.Context, sortParameters parameters.Parameters, proje return opts, nil } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBFindFilterString", t.FilterPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating filter: %s", err) + return nil, util.NewAgentError("error populating filter", err) } opts, err := getOptions(ctx, t.SortParams, t.ProjectPayload, t.Limit, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating options: %s", err) + return nil, util.NewAgentError("error populating options", err) } - return source.Find(ctx, filterString, t.Database, t.Collection, opts) + resp, err := source.Find(ctx, filterString, t.Database, t.Collection, opts) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbfindone/mongodbfindone.go b/internal/tools/mongodb/mongodbfindone/mongodbfindone.go index 3822d1302a..f75af4328a 100644 --- a/internal/tools/mongodb/mongodbfindone/mongodbfindone.go +++ b/internal/tools/mongodb/mongodbfindone/mongodbfindone.go @@ -16,10 +16,12 @@ package mongodbfindone import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" @@ -110,32 +112,36 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBFindOneFilterString", t.FilterPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating filter: %s", err) + return nil, util.NewAgentError("error populating filter", err) } opts := options.FindOne() if len(t.ProjectPayload) > 0 { result, err := parameters.PopulateTemplateWithJSON("MongoDBFindOneProjectString", t.ProjectPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating project payload: %s", err) + return nil, util.NewAgentError("error populating project payload", err) } var projection any err = bson.UnmarshalExtJSON([]byte(result), false, &projection) if err != nil { - return nil, fmt.Errorf("error unmarshalling projection: %s", err) + return nil, util.NewAgentError("error unmarshalling projection", err) } opts = opts.SetProjection(projection) } - return source.FindOne(ctx, filterString, t.Database, t.Collection, opts) + resp, err := source.FindOne(ctx, filterString, t.Database, t.Collection, opts) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go b/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go index 17a8020635..0de1cc8de4 100644 --- a/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go +++ b/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go @@ -15,13 +15,14 @@ package mongodbinsertmany import ( "context" - "errors" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" ) @@ -100,23 +101,27 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } if len(params) == 0 { - return nil, errors.New("no input found") + return nil, util.NewAgentError("no input found", nil) } paramsMap := params.AsMap() jsonData, ok := paramsMap[paramDataKey].(string) if !ok { - return nil, errors.New("no input found") + return nil, util.NewAgentError("no input found or invalid type for data", nil) } - return source.InsertMany(ctx, jsonData, t.Canonical, t.Database, t.Collection) + resp, err := source.InsertMany(ctx, jsonData, t.Canonical, t.Database, t.Collection) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go b/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go index d4e9f7f072..3fc9260a51 100644 --- a/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go +++ b/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go @@ -15,13 +15,14 @@ package mongodbinsertone import ( "context" - "errors" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" ) @@ -101,20 +102,24 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } if len(params) == 0 { - return nil, errors.New("no input found") + return nil, util.NewAgentError("no input found", nil) } // use the first, assume it's a string jsonData, ok := params[0].Value.(string) if !ok { - return nil, errors.New("no input found") + return nil, util.NewAgentError("no input found or invalid type for data", nil) } - return source.InsertOne(ctx, jsonData, t.Canonical, t.Database, t.Collection) + resp, err := source.InsertOne(ctx, jsonData, t.Canonical, t.Database, t.Collection) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go b/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go index d7a7cd569b..7e34e52384 100644 --- a/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go +++ b/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go @@ -16,12 +16,14 @@ package mongodbupdatemany import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" ) @@ -109,22 +111,26 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateManyFilter", t.FilterPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating filter: %s", err) + return nil, util.NewAgentError("error populating filter", err) } updateString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateMany", t.UpdatePayload, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to get update: %w", err) + return nil, util.NewAgentError("unable to get update", err) } - return source.UpdateMany(ctx, filterString, t.Canonical, updateString, t.Database, t.Collection, t.Upsert) + resp, err := source.UpdateMany(ctx, filterString, t.Canonical, updateString, t.Database, t.Collection, t.Upsert) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go b/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go index 2fa99efb67..6369e08a91 100644 --- a/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go +++ b/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go @@ -16,12 +16,14 @@ package mongodbupdateone import ( "context" "fmt" + "net/http" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/v2/mongo" ) @@ -110,22 +112,26 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateOneFilter", t.FilterPayload, paramsMap) if err != nil { - return nil, fmt.Errorf("error populating filter: %s", err) + return nil, util.NewAgentError("error populating filter", err) } updateString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateOne", t.UpdatePayload, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to get update: %w", err) + return nil, util.NewAgentError("unable to get update", err) } - return source.UpdateOne(ctx, filterString, t.Canonical, updateString, t.Database, t.Collection, t.Upsert) + resp, err := source.UpdateOne(ctx, filterString, t.Canonical, updateString, t.Database, t.Collection, t.Upsert) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go index 3b00090823..ae3d497b84 100644 --- a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go +++ b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -89,25 +90,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - sql, ok := paramsMap["sql"].(string) + sqlStr, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql"]), nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, sql, nil) + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sqlStr)) + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mssql/mssqllisttables/mssqllisttables.go b/internal/tools/mssql/mssqllisttables/mssqllisttables.go index 6798087768..ea462e2740 100644 --- a/internal/tools/mssql/mssqllisttables/mssqllisttables.go +++ b/internal/tools/mssql/mssqllisttables/mssqllisttables.go @@ -18,263 +18,265 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) const resourceType string = "mssql-list-tables" const listTablesStatement = ` - WITH table_info AS ( - SELECT - t.object_id AS table_oid, - s.name AS schema_name, - t.name AS table_name, - dp.name AS table_owner, -- Schema's owner principal name - CAST(ep.value AS NVARCHAR(MAX)) AS table_comment, -- Cast for JSON compatibility - CASE - WHEN EXISTS ( -- Check if the table has more than one partition for any of its indexes or heap - SELECT 1 FROM sys.partitions p - WHERE p.object_id = t.object_id AND p.partition_number > 1 - ) THEN 'PARTITIONED TABLE' - ELSE 'TABLE' - END AS object_type_detail - FROM - sys.tables t - INNER JOIN - sys.schemas s ON t.schema_id = s.schema_id - LEFT JOIN - sys.database_principals dp ON s.principal_id = dp.principal_id - LEFT JOIN - sys.extended_properties ep ON ep.major_id = t.object_id AND ep.minor_id = 0 AND ep.class = 1 AND ep.name = 'MS_Description' - WHERE - t.type = 'U' -- User tables - AND s.name NOT IN ('sys', 'INFORMATION_SCHEMA', 'guest', 'db_owner', 'db_accessadmin', 'db_backupoperator', 'db_datareader', 'db_datawriter', 'db_ddladmin', 'db_denydatareader', 'db_denydatawriter', 'db_securityadmin') - AND (@table_names IS NULL OR LTRIM(RTRIM(@table_names)) = '' OR t.name IN (SELECT LTRIM(RTRIM(value)) FROM STRING_SPLIT(@table_names, ','))) - ), - columns_info AS ( - SELECT - c.object_id AS table_oid, - c.name AS column_name, - CONCAT( - UPPER(TY.name), -- Base type name - CASE - WHEN TY.name IN ('char', 'varchar', 'nchar', 'nvarchar', 'binary', 'varbinary') THEN - CONCAT('(', IIF(c.max_length = -1, 'MAX', CAST(c.max_length / CASE WHEN TY.name IN ('nchar', 'nvarchar') THEN 2 ELSE 1 END AS VARCHAR(10))), ')') - WHEN TY.name IN ('decimal', 'numeric') THEN - CONCAT('(', c.precision, ',', c.scale, ')') - WHEN TY.name IN ('datetime2', 'datetimeoffset', 'time') THEN - CONCAT('(', c.scale, ')') - ELSE '' - END - ) AS data_type, - c.column_id AS column_ordinal_position, - IIF(c.is_nullable = 0, CAST(1 AS BIT), CAST(0 AS BIT)) AS is_not_nullable, - dc.definition AS column_default, - CAST(epc.value AS NVARCHAR(MAX)) AS column_comment - FROM - sys.columns c - JOIN - table_info ti ON c.object_id = ti.table_oid - JOIN - sys.types TY ON c.user_type_id = TY.user_type_id AND TY.is_user_defined = 0 -- Ensure we get base types - LEFT JOIN - sys.default_constraints dc ON c.object_id = dc.parent_object_id AND c.column_id = dc.parent_column_id - LEFT JOIN - sys.extended_properties epc ON epc.major_id = c.object_id AND epc.minor_id = c.column_id AND epc.class = 1 AND epc.name = 'MS_Description' - ), - constraints_info AS ( - -- Primary Keys & Unique Constraints - SELECT - kc.parent_object_id AS table_oid, - kc.name AS constraint_name, - REPLACE(kc.type_desc, '_CONSTRAINT', '') AS constraint_type, -- 'PRIMARY_KEY', 'UNIQUE' - STUFF((SELECT ', ' + col.name - FROM sys.index_columns ic - JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id - WHERE ic.object_id = kc.parent_object_id AND ic.index_id = kc.unique_index_id - ORDER BY ic.key_ordinal - FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS constraint_columns, - NULL AS foreign_key_referenced_table, - NULL AS foreign_key_referenced_columns, - CASE kc.type - WHEN 'PK' THEN 'PRIMARY KEY (' + STUFF((SELECT ', ' + col.name FROM sys.index_columns ic JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id WHERE ic.object_id = kc.parent_object_id AND ic.index_id = kc.unique_index_id ORDER BY ic.key_ordinal FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') + ')' - WHEN 'UQ' THEN 'UNIQUE (' + STUFF((SELECT ', ' + col.name FROM sys.index_columns ic JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id WHERE ic.object_id = kc.parent_object_id AND ic.index_id = kc.unique_index_id ORDER BY ic.key_ordinal FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') + ')' - END AS constraint_definition - FROM sys.key_constraints kc - JOIN table_info ti ON kc.parent_object_id = ti.table_oid - UNION ALL - -- Foreign Keys - SELECT - fk.parent_object_id AS table_oid, - fk.name AS constraint_name, - 'FOREIGN KEY' AS constraint_type, - STUFF((SELECT ', ' + pc.name - FROM sys.foreign_key_columns fkc - JOIN sys.columns pc ON fkc.parent_object_id = pc.object_id AND fkc.parent_column_id = pc.column_id - WHERE fkc.constraint_object_id = fk.object_id - ORDER BY fkc.constraint_column_id - FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS constraint_columns, - SCHEMA_NAME(rt.schema_id) + '.' + OBJECT_NAME(fk.referenced_object_id) AS foreign_key_referenced_table, - STUFF((SELECT ', ' + rc.name - FROM sys.foreign_key_columns fkc - JOIN sys.columns rc ON fkc.referenced_object_id = rc.object_id AND fkc.referenced_column_id = rc.column_id - WHERE fkc.constraint_object_id = fk.object_id - ORDER BY fkc.constraint_column_id - FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS foreign_key_referenced_columns, - OBJECT_DEFINITION(fk.object_id) AS constraint_definition - FROM sys.foreign_keys fk - JOIN sys.tables rt ON fk.referenced_object_id = rt.object_id - JOIN table_info ti ON fk.parent_object_id = ti.table_oid - UNION ALL - -- Check Constraints - SELECT - cc.parent_object_id AS table_oid, - cc.name AS constraint_name, - 'CHECK' AS constraint_type, - NULL AS constraint_columns, -- Definition includes column context - NULL AS foreign_key_referenced_table, - NULL AS foreign_key_referenced_columns, - cc.definition AS constraint_definition - FROM sys.check_constraints cc - JOIN table_info ti ON cc.parent_object_id = ti.table_oid - ), - indexes_info AS ( - SELECT - i.object_id AS table_oid, - i.name AS index_name, - i.type_desc AS index_method, -- CLUSTERED, NONCLUSTERED, XML, etc. - i.is_unique, - i.is_primary_key AS is_primary, - STUFF((SELECT ', ' + c.name - FROM sys.index_columns ic - JOIN sys.columns c ON i.object_id = c.object_id AND ic.column_id = c.column_id - WHERE ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.is_included_column = 0 - ORDER BY ic.key_ordinal - FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS index_columns, - ( - 'COLUMNS: (' + ISNULL(STUFF((SELECT ', ' + c.name + CASE WHEN ic.is_descending_key = 1 THEN ' DESC' ELSE '' END - FROM sys.index_columns ic - JOIN sys.columns c ON i.object_id = c.object_id AND ic.column_id = c.column_id - WHERE ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.is_included_column = 0 - ORDER BY ic.key_ordinal FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, ''), 'N/A') + ')' + - ISNULL(CHAR(13)+CHAR(10) + 'INCLUDE: (' + STUFF((SELECT ', ' + c.name - FROM sys.index_columns ic - JOIN sys.columns c ON i.object_id = c.object_id AND ic.column_id = c.column_id - WHERE ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.is_included_column = 1 - ORDER BY ic.index_column_id FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') + ')', '') + - ISNULL(CHAR(13)+CHAR(10) + 'FILTER: (' + i.filter_definition + ')', '') - ) AS index_definition_details - FROM - sys.indexes i - JOIN - table_info ti ON i.object_id = ti.table_oid - WHERE i.type <> 0 -- Exclude Heaps - AND i.name IS NOT NULL -- Exclude unnamed heap indexes; named indexes (PKs are often named) are preferred. - ), - triggers_info AS ( - SELECT - tr.parent_id AS table_oid, - tr.name AS trigger_name, - OBJECT_DEFINITION(tr.object_id) AS trigger_definition, - CASE tr.is_disabled WHEN 0 THEN 'ENABLED' ELSE 'DISABLED' END AS trigger_enabled_state - FROM - sys.triggers tr - JOIN - table_info ti ON tr.parent_id = ti.table_oid - WHERE - tr.is_ms_shipped = 0 - AND tr.parent_class_desc = 'OBJECT_OR_COLUMN' -- DML Triggers on tables/views - ) - SELECT - ti.schema_name, - ti.table_name AS object_name, - CASE - WHEN @output_format = 'simple' THEN - (SELECT ti.table_name AS name FOR JSON PATH, WITHOUT_ARRAY_WRAPPER) - ELSE - ( - SELECT - ti.schema_name AS schema_name, - ti.table_name AS object_name, - ti.object_type_detail AS object_type, - ti.table_owner AS owner, - ti.table_comment AS comment, - JSON_QUERY(ISNULL(( - SELECT - ci.column_name, - ci.data_type, - ci.column_ordinal_position, - ci.is_not_nullable, - ci.column_default, - ci.column_comment - FROM columns_info ci - WHERE ci.table_oid = ti.table_oid - ORDER BY ci.column_ordinal_position - FOR JSON PATH - ), '[]')) AS columns, - JSON_QUERY(ISNULL(( - SELECT - cons.constraint_name, - cons.constraint_type, - cons.constraint_definition, - JSON_QUERY( - CASE - WHEN cons.constraint_columns IS NOT NULL AND LTRIM(RTRIM(cons.constraint_columns)) <> '' - THEN '[' + (SELECT STRING_AGG('"' + LTRIM(RTRIM(value)) + '"', ',') FROM STRING_SPLIT(cons.constraint_columns, ',')) + ']' - ELSE '[]' - END - ) AS constraint_columns, - cons.foreign_key_referenced_table, - JSON_QUERY( - CASE - WHEN cons.foreign_key_referenced_columns IS NOT NULL AND LTRIM(RTRIM(cons.foreign_key_referenced_columns)) <> '' - THEN '[' + (SELECT STRING_AGG('"' + LTRIM(RTRIM(value)) + '"', ',') FROM STRING_SPLIT(cons.foreign_key_referenced_columns, ',')) + ']' - ELSE '[]' - END - ) AS foreign_key_referenced_columns - FROM constraints_info cons - WHERE cons.table_oid = ti.table_oid - FOR JSON PATH - ), '[]')) AS constraints, - JSON_QUERY(ISNULL(( - SELECT - ii.index_name, - ii.index_definition_details AS index_definition, - ii.is_unique, - ii.is_primary, - ii.index_method, - JSON_QUERY( - CASE - WHEN ii.index_columns IS NOT NULL AND LTRIM(RTRIM(ii.index_columns)) <> '' - THEN '[' + (SELECT STRING_AGG('"' + LTRIM(RTRIM(value)) + '"', ',') FROM STRING_SPLIT(ii.index_columns, ',')) + ']' - ELSE '[]' - END - ) AS index_columns - FROM indexes_info ii - WHERE ii.table_oid = ti.table_oid - FOR JSON PATH - ), '[]')) AS indexes, - JSON_QUERY(ISNULL(( - SELECT - tri.trigger_name, - tri.trigger_definition, - tri.trigger_enabled_state - FROM triggers_info tri - WHERE tri.table_oid = ti.table_oid - FOR JSON PATH - ), '[]')) AS triggers - FOR JSON PATH, WITHOUT_ARRAY_WRAPPER -- Creates a single JSON object for this table's details - ) - END AS object_details - FROM - table_info ti - ORDER BY - ti.schema_name, ti.table_name; + WITH table_info AS ( + SELECT + t.object_id AS table_oid, + s.name AS schema_name, + t.name AS table_name, + dp.name AS table_owner, -- Schema's owner principal name + CAST(ep.value AS NVARCHAR(MAX)) AS table_comment, -- Cast for JSON compatibility + CASE + WHEN EXISTS ( -- Check if the table has more than one partition for any of its indexes or heap + SELECT 1 FROM sys.partitions p + WHERE p.object_id = t.object_id AND p.partition_number > 1 + ) THEN 'PARTITIONED TABLE' + ELSE 'TABLE' + END AS object_type_detail + FROM + sys.tables t + INNER JOIN + sys.schemas s ON t.schema_id = s.schema_id + LEFT JOIN + sys.database_principals dp ON s.principal_id = dp.principal_id + LEFT JOIN + sys.extended_properties ep ON ep.major_id = t.object_id AND ep.minor_id = 0 AND ep.class = 1 AND ep.name = 'MS_Description' + WHERE + t.type = 'U' -- User tables + AND s.name NOT IN ('sys', 'INFORMATION_SCHEMA', 'guest', 'db_owner', 'db_accessadmin', 'db_backupoperator', 'db_datareader', 'db_datawriter', 'db_ddladmin', 'db_denydatareader', 'db_denydatawriter', 'db_securityadmin') + AND (@table_names IS NULL OR LTRIM(RTRIM(@table_names)) = '' OR t.name IN (SELECT LTRIM(RTRIM(value)) FROM STRING_SPLIT(@table_names, ','))) + ), + columns_info AS ( + SELECT + c.object_id AS table_oid, + c.name AS column_name, + CONCAT( + UPPER(TY.name), -- Base type name + CASE + WHEN TY.name IN ('char', 'varchar', 'nchar', 'nvarchar', 'binary', 'varbinary') THEN + CONCAT('(', IIF(c.max_length = -1, 'MAX', CAST(c.max_length / CASE WHEN TY.name IN ('nchar', 'nvarchar') THEN 2 ELSE 1 END AS VARCHAR(10))), ')') + WHEN TY.name IN ('decimal', 'numeric') THEN + CONCAT('(', c.precision, ',', c.scale, ')') + WHEN TY.name IN ('datetime2', 'datetimeoffset', 'time') THEN + CONCAT('(', c.scale, ')') + ELSE '' + END + ) AS data_type, + c.column_id AS column_ordinal_position, + IIF(c.is_nullable = 0, CAST(1 AS BIT), CAST(0 AS BIT)) AS is_not_nullable, + dc.definition AS column_default, + CAST(epc.value AS NVARCHAR(MAX)) AS column_comment + FROM + sys.columns c + JOIN + table_info ti ON c.object_id = ti.table_oid + JOIN + sys.types TY ON c.user_type_id = TY.user_type_id AND TY.is_user_defined = 0 -- Ensure we get base types + LEFT JOIN + sys.default_constraints dc ON c.object_id = dc.parent_object_id AND c.column_id = dc.parent_column_id + LEFT JOIN + sys.extended_properties epc ON epc.major_id = c.object_id AND epc.minor_id = c.column_id AND epc.class = 1 AND epc.name = 'MS_Description' + ), + constraints_info AS ( + -- Primary Keys & Unique Constraints + SELECT + kc.parent_object_id AS table_oid, + kc.name AS constraint_name, + REPLACE(kc.type_desc, '_CONSTRAINT', '') AS constraint_type, -- 'PRIMARY_KEY', 'UNIQUE' + STUFF((SELECT ', ' + col.name + FROM sys.index_columns ic + JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id + WHERE ic.object_id = kc.parent_object_id AND ic.index_id = kc.unique_index_id + ORDER BY ic.key_ordinal + FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS constraint_columns, + NULL AS foreign_key_referenced_table, + NULL AS foreign_key_referenced_columns, + CASE kc.type + WHEN 'PK' THEN 'PRIMARY KEY (' + STUFF((SELECT ', ' + col.name FROM sys.index_columns ic JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id WHERE ic.object_id = kc.parent_object_id AND ic.index_id = kc.unique_index_id ORDER BY ic.key_ordinal FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') + ')' + WHEN 'UQ' THEN 'UNIQUE (' + STUFF((SELECT ', ' + col.name FROM sys.index_columns ic JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id WHERE ic.object_id = kc.parent_object_id AND ic.index_id = kc.unique_index_id ORDER BY ic.key_ordinal FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') + ')' + END AS constraint_definition + FROM sys.key_constraints kc + JOIN table_info ti ON kc.parent_object_id = ti.table_oid + UNION ALL + -- Foreign Keys + SELECT + fk.parent_object_id AS table_oid, + fk.name AS constraint_name, + 'FOREIGN KEY' AS constraint_type, + STUFF((SELECT ', ' + pc.name + FROM sys.foreign_key_columns fkc + JOIN sys.columns pc ON fkc.parent_object_id = pc.object_id AND fkc.parent_column_id = pc.column_id + WHERE fkc.constraint_object_id = fk.object_id + ORDER BY fkc.constraint_column_id + FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS constraint_columns, + SCHEMA_NAME(rt.schema_id) + '.' + OBJECT_NAME(fk.referenced_object_id) AS foreign_key_referenced_table, + STUFF((SELECT ', ' + rc.name + FROM sys.foreign_key_columns fkc + JOIN sys.columns rc ON fkc.referenced_object_id = rc.object_id AND fkc.referenced_column_id = rc.column_id + WHERE fkc.constraint_object_id = fk.object_id + ORDER BY fkc.constraint_column_id + FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS foreign_key_referenced_columns, + OBJECT_DEFINITION(fk.object_id) AS constraint_definition + FROM sys.foreign_keys fk + JOIN sys.tables rt ON fk.referenced_object_id = rt.object_id + JOIN table_info ti ON fk.parent_object_id = ti.table_oid + UNION ALL + -- Check Constraints + SELECT + cc.parent_object_id AS table_oid, + cc.name AS constraint_name, + 'CHECK' AS constraint_type, + NULL AS constraint_columns, -- Definition includes column context + NULL AS foreign_key_referenced_table, + NULL AS foreign_key_referenced_columns, + cc.definition AS constraint_definition + FROM sys.check_constraints cc + JOIN table_info ti ON cc.parent_object_id = ti.table_oid + ), + indexes_info AS ( + SELECT + i.object_id AS table_oid, + i.name AS index_name, + i.type_desc AS index_method, -- CLUSTERED, NONCLUSTERED, XML, etc. + i.is_unique, + i.is_primary_key AS is_primary, + STUFF((SELECT ', ' + c.name + FROM sys.index_columns ic + JOIN sys.columns c ON i.object_id = c.object_id AND ic.column_id = c.column_id + WHERE ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.is_included_column = 0 + ORDER BY ic.key_ordinal + FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') AS index_columns, + ( + 'COLUMNS: (' + ISNULL(STUFF((SELECT ', ' + c.name + CASE WHEN ic.is_descending_key = 1 THEN ' DESC' ELSE '' END + FROM sys.index_columns ic + JOIN sys.columns c ON i.object_id = c.object_id AND ic.column_id = c.column_id + WHERE ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.is_included_column = 0 + ORDER BY ic.key_ordinal FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, ''), 'N/A') + ')' + + ISNULL(CHAR(13)+CHAR(10) + 'INCLUDE: (' + STUFF((SELECT ', ' + c.name + FROM sys.index_columns ic + JOIN sys.columns c ON i.object_id = c.object_id AND ic.column_id = c.column_id + WHERE ic.object_id = i.object_id AND ic.index_id = i.index_id AND ic.is_included_column = 1 + ORDER BY ic.index_column_id FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 2, '') + ')', '') + + ISNULL(CHAR(13)+CHAR(10) + 'FILTER: (' + i.filter_definition + ')', '') + ) AS index_definition_details + FROM + sys.indexes i + JOIN + table_info ti ON i.object_id = ti.table_oid + WHERE i.type <> 0 -- Exclude Heaps + AND i.name IS NOT NULL -- Exclude unnamed heap indexes; named indexes (PKs are often named) are preferred. + ), + triggers_info AS ( + SELECT + tr.parent_id AS table_oid, + tr.name AS trigger_name, + OBJECT_DEFINITION(tr.object_id) AS trigger_definition, + CASE tr.is_disabled WHEN 0 THEN 'ENABLED' ELSE 'DISABLED' END AS trigger_enabled_state + FROM + sys.triggers tr + JOIN + table_info ti ON tr.parent_id = ti.table_oid + WHERE + tr.is_ms_shipped = 0 + AND tr.parent_class_desc = 'OBJECT_OR_COLUMN' -- DML Triggers on tables/views + ) + SELECT + ti.schema_name, + ti.table_name AS object_name, + CASE + WHEN @output_format = 'simple' THEN + (SELECT ti.table_name AS name FOR JSON PATH, WITHOUT_ARRAY_WRAPPER) + ELSE + ( + SELECT + ti.schema_name AS schema_name, + ti.table_name AS object_name, + ti.object_type_detail AS object_type, + ti.table_owner AS owner, + ti.table_comment AS comment, + JSON_QUERY(ISNULL(( + SELECT + ci.column_name, + ci.data_type, + ci.column_ordinal_position, + ci.is_not_nullable, + ci.column_default, + ci.column_comment + FROM columns_info ci + WHERE ci.table_oid = ti.table_oid + ORDER BY ci.column_ordinal_position + FOR JSON PATH + ), '[]')) AS columns, + JSON_QUERY(ISNULL(( + SELECT + cons.constraint_name, + cons.constraint_type, + cons.constraint_definition, + JSON_QUERY( + CASE + WHEN cons.constraint_columns IS NOT NULL AND LTRIM(RTRIM(cons.constraint_columns)) <> '' + THEN '[' + (SELECT STRING_AGG('"' + LTRIM(RTRIM(value)) + '"', ',') FROM STRING_SPLIT(cons.constraint_columns, ',')) + ']' + ELSE '[]' + END + ) AS constraint_columns, + cons.foreign_key_referenced_table, + JSON_QUERY( + CASE + WHEN cons.foreign_key_referenced_columns IS NOT NULL AND LTRIM(RTRIM(cons.foreign_key_referenced_columns)) <> '' + THEN '[' + (SELECT STRING_AGG('"' + LTRIM(RTRIM(value)) + '"', ',') FROM STRING_SPLIT(cons.foreign_key_referenced_columns, ',')) + ']' + ELSE '[]' + END + ) AS foreign_key_referenced_columns + FROM constraints_info cons + WHERE cons.table_oid = ti.table_oid + FOR JSON PATH + ), '[]')) AS constraints, + JSON_QUERY(ISNULL(( + SELECT + ii.index_name, + ii.index_definition_details AS index_definition, + ii.is_unique, + ii.is_primary, + ii.index_method, + JSON_QUERY( + CASE + WHEN ii.index_columns IS NOT NULL AND LTRIM(RTRIM(ii.index_columns)) <> '' + THEN '[' + (SELECT STRING_AGG('"' + LTRIM(RTRIM(value)) + '"', ',') FROM STRING_SPLIT(ii.index_columns, ',')) + ']' + ELSE '[]' + END + ) AS index_columns + FROM indexes_info ii + WHERE ii.table_oid = ti.table_oid + FOR JSON PATH + ), '[]')) AS indexes, + JSON_QUERY(ISNULL(( + SELECT + tri.trigger_name, + tri.trigger_definition, + tri.trigger_enabled_state + FROM triggers_info tri + WHERE tri.table_oid = ti.table_oid + FOR JSON PATH + ), '[]')) AS triggers + FOR JSON PATH, WITHOUT_ARRAY_WRAPPER -- Creates a single JSON object for this table's details + ) + END AS object_details + FROM + table_info ti + ORDER BY + ti.schema_name, ti.table_name; ` func init() { @@ -339,17 +341,17 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() outputFormat, _ := paramsMap["output_format"].(string) if outputFormat != "simple" && outputFormat != "detailed" { - return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) + return nil, util.NewAgentError(fmt.Sprintf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat), nil) } namedArgs := []any{ @@ -358,14 +360,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } resp, err := source.RunSQL(ctx, listTablesStatement, namedArgs) if err != nil { - return nil, err + return nil, util.ProcessGeneralError(err) } // if there's no results, return empty list instead of null resSlice, ok := resp.([]any) if !ok || len(resSlice) == 0 { return []any{}, nil } - return resp, err + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mssql/mssqlsql/mssqlsql.go b/internal/tools/mssql/mssqlsql/mssqlsql.go index 57b67ec9ac..4e5878e89f 100644 --- a/internal/tools/mssql/mssqlsql/mssqlsql.go +++ b/internal/tools/mssql/mssqlsql/mssqlsql.go @@ -18,12 +18,14 @@ import ( "context" "database/sql" "fmt" + "net/http" "strings" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -94,21 +96,21 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } namedArgs := make([]any, 0, len(newParams)) @@ -123,7 +125,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para namedArgs = append(namedArgs, value) } } - return source.RunSQL(ctx, newStatement, namedArgs) + resp, err := source.RunSQL(ctx, newStatement, namedArgs) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go index fb4c6a0a97..4363ba2ed7 100644 --- a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go +++ b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -89,25 +90,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - sql, ok := paramsMap["sql"].(string) + sqlStr, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql"]), nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, sql, nil) + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sqlStr)) + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go index d152cc2394..b2e6008af2 100644 --- a/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go +++ b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -91,46 +92,46 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - sql, ok := paramsMap["sql_statement"].(string) + sqlStr, ok := paramsMap["sql_statement"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql_statement"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql_statement"]), nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sqlStr)) - query := fmt.Sprintf("EXPLAIN FORMAT=JSON %s", sql) + query := fmt.Sprintf("EXPLAIN FORMAT=JSON %s", sqlStr) result, err := source.RunSQL(ctx, query, nil) if err != nil { - return nil, err + return nil, util.ProcessGeneralError(err) } // extract and return only the query plan object resSlice, ok := result.([]any) if !ok || len(resSlice) == 0 { - return nil, fmt.Errorf("no query plan returned") + return nil, util.NewClientServerError("no query plan returned", http.StatusInternalServerError, nil) } row, ok := resSlice[0].(orderedmap.Row) if !ok || len(row.Columns) == 0 { - return nil, fmt.Errorf("no query plan returned in row") + return nil, util.NewClientServerError("no query plan returned in row", http.StatusInternalServerError, nil) } plan, ok := row.Columns[0].Value.(string) if !ok { - return nil, fmt.Errorf("unable to convert plan object to string") + return nil, util.NewClientServerError("unable to convert plan object to string", http.StatusInternalServerError, nil) } var out map[string]any if err := json.Unmarshal([]byte(plan), &out); err != nil { - return nil, fmt.Errorf("failed to unmarshal query plan json: %w", err) + return nil, util.NewClientServerError("failed to unmarshal query plan json", http.StatusInternalServerError, err) } return out, nil } diff --git a/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go b/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go index d08b57f0ce..3437657da6 100644 --- a/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go +++ b/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -32,65 +33,65 @@ import ( const resourceType string = "mysql-list-active-queries" const listActiveQueriesStatementMySQL = ` - SELECT - p.id AS processlist_id, - substring(IFNULL(p.info, t.trx_query), 1, 100) AS query, - t.trx_started AS trx_started, - (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_started)) AS trx_duration_seconds, - (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_wait_started)) AS trx_wait_duration_seconds, - p.time AS query_time, - t.trx_state AS trx_state, - p.state AS process_state, - IF(p.host IS NULL OR p.host = '', p.user, concat(p.user, '@', SUBSTRING_INDEX(p.host, ':', 1))) AS user, - t.trx_rows_locked AS trx_rows_locked, - t.trx_rows_modified AS trx_rows_modified, - p.db AS db - FROM - information_schema.processlist p - LEFT OUTER JOIN - information_schema.innodb_trx t - ON p.id = t.trx_mysql_thread_id - WHERE - (? IS NULL OR p.time >= ?) - AND p.id != CONNECTION_ID() - AND Command NOT IN ('Binlog Dump', 'Binlog Dump GTID', 'Connect', 'Connect Out', 'Register Slave') - AND User NOT IN ('system user', 'event_scheduler') - AND (t.trx_id is NOT NULL OR command != 'Sleep') - ORDER BY - t.trx_started - LIMIT ?; + SELECT + p.id AS processlist_id, + substring(IFNULL(p.info, t.trx_query), 1, 100) AS query, + t.trx_started AS trx_started, + (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_started)) AS trx_duration_seconds, + (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_wait_started)) AS trx_wait_duration_seconds, + p.time AS query_time, + t.trx_state AS trx_state, + p.state AS process_state, + IF(p.host IS NULL OR p.host = '', p.user, concat(p.user, '@', SUBSTRING_INDEX(p.host, ':', 1))) AS user, + t.trx_rows_locked AS trx_rows_locked, + t.trx_rows_modified AS trx_rows_modified, + p.db AS db + FROM + information_schema.processlist p + LEFT OUTER JOIN + information_schema.innodb_trx t + ON p.id = t.trx_mysql_thread_id + WHERE + (? IS NULL OR p.time >= ?) + AND p.id != CONNECTION_ID() + AND Command NOT IN ('Binlog Dump', 'Binlog Dump GTID', 'Connect', 'Connect Out', 'Register Slave') + AND User NOT IN ('system user', 'event_scheduler') + AND (t.trx_id is NOT NULL OR command != 'Sleep') + ORDER BY + t.trx_started + LIMIT ?; ` const listActiveQueriesStatementCloudSQLMySQL = ` - SELECT - p.id AS processlist_id, - substring(IFNULL(p.info, t.trx_query), 1, 100) AS query, - t.trx_started AS trx_started, - (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_started)) AS trx_duration_seconds, - (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_wait_started)) AS trx_wait_duration_seconds, - p.time AS query_time, - t.trx_state AS trx_state, - p.state AS process_state, - IF(p.host IS NULL OR p.host = '', p.user, concat(p.user, '@', SUBSTRING_INDEX(p.host, ':', 1))) AS user, - t.trx_rows_locked AS trx_rows_locked, - t.trx_rows_modified AS trx_rows_modified, - p.db AS db - FROM - information_schema.processlist p - LEFT OUTER JOIN - information_schema.innodb_trx t - ON p.id = t.trx_mysql_thread_id - WHERE - (? IS NULL OR p.time >= ?) - AND p.id != CONNECTION_ID() - AND SUBSTRING_INDEX(IFNULL(p.host,''), ':', 1) NOT IN ('localhost', '127.0.0.1') - AND IFNULL(p.host,'') NOT LIKE '::1%' - AND Command NOT IN ('Binlog Dump', 'Binlog Dump GTID', 'Connect', 'Connect Out', 'Register Slave') - AND User NOT IN ('system user', 'event_scheduler') - AND (t.trx_id is NOT NULL OR command != 'sleep') - ORDER BY - t.trx_started - LIMIT ?; + SELECT + p.id AS processlist_id, + substring(IFNULL(p.info, t.trx_query), 1, 100) AS query, + t.trx_started AS trx_started, + (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_started)) AS trx_duration_seconds, + (UNIX_TIMESTAMP(UTC_TIMESTAMP()) - UNIX_TIMESTAMP(t.trx_wait_started)) AS trx_wait_duration_seconds, + p.time AS query_time, + t.trx_state AS trx_state, + p.state AS process_state, + IF(p.host IS NULL OR p.host = '', p.user, concat(p.user, '@', SUBSTRING_INDEX(p.host, ':', 1))) AS user, + t.trx_rows_locked AS trx_rows_locked, + t.trx_rows_modified AS trx_rows_modified, + p.db AS db + FROM + information_schema.processlist p + LEFT OUTER JOIN + information_schema.innodb_trx t + ON p.id = t.trx_mysql_thread_id + WHERE + (? IS NULL OR p.time >= ?) + AND p.id != CONNECTION_ID() + AND SUBSTRING_INDEX(IFNULL(p.host,''), ':', 1) NOT IN ('localhost', '127.0.0.1') + AND IFNULL(p.host,'') NOT LIKE '::1%' + AND Command NOT IN ('Binlog Dump', 'Binlog Dump GTID', 'Connect', 'Connect Out', 'Register Slave') + AND User NOT IN ('system user', 'event_scheduler') + AND (t.trx_id is NOT NULL OR command != 'sleep') + ORDER BY + t.trx_started + LIMIT ?; ` func init() { @@ -177,30 +178,34 @@ type Tool struct { statement string } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() duration, ok := paramsMap["min_duration_secs"].(int) if !ok { - return nil, fmt.Errorf("invalid 'min_duration_secs' parameter; expected an integer") + return nil, util.NewAgentError("invalid 'min_duration_secs' parameter; expected an integer", nil) } limit, ok := paramsMap["limit"].(int) if !ok { - return nil, fmt.Errorf("invalid 'limit' parameter; expected an integer") + return nil, util.NewAgentError("invalid 'limit' parameter; expected an integer", nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, t.statement)) - return source.RunSQL(ctx, t.statement, []any{duration, duration, limit}) + resp, err := source.RunSQL(ctx, t.statement, []any{duration, duration, limit}) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go b/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go index a6954284e5..4277a6379d 100644 --- a/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go +++ b/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -30,25 +31,25 @@ import ( const resourceType string = "mysql-list-table-fragmentation" const listTableFragmentationStatement = ` - SELECT - table_schema, - table_name, - data_length AS data_size, - index_length AS index_size, - data_free AS data_free, - ROUND((data_free / (data_length + index_length)) * 100, 2) AS fragmentation_percentage - FROM - information_schema.tables - WHERE - table_schema NOT IN ('sys', 'performance_schema', 'mysql', 'information_schema') - AND (COALESCE(?, '') = '' OR table_schema = ?) - AND (COALESCE(?, '') = '' OR table_name = ?) - AND data_free >= ? - ORDER BY - fragmentation_percentage DESC, - table_schema, - table_name - LIMIT ?; + SELECT + table_schema, + table_name, + data_length AS data_size, + index_length AS index_size, + data_free AS data_free, + ROUND((data_free / (data_length + index_length)) * 100, 2) AS fragmentation_percentage + FROM + information_schema.tables + WHERE + table_schema NOT IN ('sys', 'performance_schema', 'mysql', 'information_schema') + AND (COALESCE(?, '') = '' OR table_schema = ?) + AND (COALESCE(?, '') = '' OR table_name = ?) + AND data_free >= ? + ORDER BY + fragmentation_percentage DESC, + table_schema, + table_name + LIMIT ?; ` func init() { @@ -114,39 +115,43 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() table_schema, ok := paramsMap["table_schema"].(string) if !ok { - return nil, fmt.Errorf("invalid 'table_schema' parameter; expected a string") + return nil, util.NewAgentError("invalid 'table_schema' parameter; expected a string", nil) } table_name, ok := paramsMap["table_name"].(string) if !ok { - return nil, fmt.Errorf("invalid 'table_name' parameter; expected a string") + return nil, util.NewAgentError("invalid 'table_name' parameter; expected a string", nil) } data_free_threshold_bytes, ok := paramsMap["data_free_threshold_bytes"].(int) if !ok { - return nil, fmt.Errorf("invalid 'data_free_threshold_bytes' parameter; expected an integer") + return nil, util.NewAgentError("invalid 'data_free_threshold_bytes' parameter; expected an integer", nil) } limit, ok := paramsMap["limit"].(int) if !ok { - return nil, fmt.Errorf("invalid 'limit' parameter; expected an integer") + return nil, util.NewAgentError("invalid 'limit' parameter; expected an integer", nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, listTableFragmentationStatement)) sliceParams := []any{table_schema, table_schema, table_name, table_name, data_free_threshold_bytes, limit} - return source.RunSQL(ctx, listTableFragmentationStatement, sliceParams) + resp, err := source.RunSQL(ctx, listTableFragmentationStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqllisttables/mysqllisttables.go b/internal/tools/mysql/mysqllisttables/mysqllisttables.go index 9f8879917a..cfca0f87c6 100644 --- a/internal/tools/mysql/mysqllisttables/mysqllisttables.go +++ b/internal/tools/mysql/mysqllisttables/mysqllisttables.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -244,32 +246,32 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() tableNames, ok := paramsMap["table_names"].(string) if !ok { - return nil, fmt.Errorf("invalid '%s' parameter; expected a string", tableNames) + return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string", tableNames), nil) } outputFormat, _ := paramsMap["output_format"].(string) if outputFormat != "simple" && outputFormat != "detailed" { - return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) + return nil, util.NewAgentError(fmt.Sprintf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat), nil) } resp, err := source.RunSQL(ctx, listTablesStatement, []any{tableNames, outputFormat}) if err != nil { - return nil, err + return nil, util.ProcessGeneralError(err) } // if there's no results, return empty list instead of null resSlice, ok := resp.([]any) if !ok || len(resSlice) == 0 { return []any{}, nil } - return resp, err + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go b/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go index 5cdeeae61f..50954e6f83 100644 --- a/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go +++ b/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -30,26 +31,26 @@ import ( const resourceType string = "mysql-list-tables-missing-unique-indexes" const listTablesMissingUniqueIndexesStatement = ` - SELECT - tab.table_schema AS table_schema, - tab.table_name AS table_name - FROM - information_schema.tables tab - LEFT JOIN - information_schema.table_constraints tco - ON - tab.table_schema = tco.table_schema - AND tab.table_name = tco.table_name - AND tco.constraint_type IN ('PRIMARY KEY', 'UNIQUE') - WHERE - tco.constraint_type IS NULL - AND tab.table_schema NOT IN('mysql', 'information_schema', 'performance_schema', 'sys') - AND tab.table_type = 'BASE TABLE' - AND (COALESCE(?, '') = '' OR tab.table_schema = ?) - ORDER BY - tab.table_schema, - tab.table_name - LIMIT ?; + SELECT + tab.table_schema AS table_schema, + tab.table_name AS table_name + FROM + information_schema.tables tab + LEFT JOIN + information_schema.table_constraints tco + ON + tab.table_schema = tco.table_schema + AND tab.table_name = tco.table_name + AND tco.constraint_type IN ('PRIMARY KEY', 'UNIQUE') + WHERE + tco.constraint_type IS NULL + AND tab.table_schema NOT IN('mysql', 'information_schema', 'performance_schema', 'sys') + AND tab.table_type = 'BASE TABLE' + AND (COALESCE(?, '') = '' OR tab.table_schema = ?) + ORDER BY + tab.table_schema, + tab.table_name + LIMIT ?; ` func init() { @@ -113,30 +114,34 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() table_schema, ok := paramsMap["table_schema"].(string) if !ok { - return nil, fmt.Errorf("invalid 'table_schema' parameter; expected a string") + return nil, util.NewAgentError("invalid 'table_schema' parameter; expected a string", nil) } limit, ok := paramsMap["limit"].(int) if !ok { - return nil, fmt.Errorf("invalid 'limit' parameter; expected an integer") + return nil, util.NewAgentError("invalid 'limit' parameter; expected an integer", nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, listTablesMissingUniqueIndexesStatement)) - return source.RunSQL(ctx, listTablesMissingUniqueIndexesStatement, []any{table_schema, table_schema, limit}) + resp, err := source.RunSQL(ctx, listTablesMissingUniqueIndexesStatement, []any{table_schema, table_schema, limit}) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqlsql/mysqlsql.go b/internal/tools/mysql/mysqlsql/mysqlsql.go index 79c0adbaf5..e65e562128 100644 --- a/internal/tools/mysql/mysqlsql/mysqlsql.go +++ b/internal/tools/mysql/mysqlsql/mysqlsql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -93,25 +95,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + resp, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/neo4j/neo4jcypher/neo4jcypher.go b/internal/tools/neo4j/neo4jcypher/neo4jcypher.go index 3c9459ff63..fc4cb89f1b 100644 --- a/internal/tools/neo4j/neo4jcypher/neo4jcypher.go +++ b/internal/tools/neo4j/neo4jcypher/neo4jcypher.go @@ -17,12 +17,14 @@ package neo4jcypher import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -85,14 +87,18 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() - return source.RunQuery(ctx, t.Statement, paramsMap, false, false) + resp, err := source.RunQuery(ctx, t.Statement, paramsMap, false, false) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go b/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go index ef32d1c6e7..2ea2fd9681 100644 --- a/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go +++ b/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go @@ -17,11 +17,13 @@ package neo4jexecutecypher import ( "context" "fmt" + "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -94,28 +96,32 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() cypherStr, ok := paramsMap["cypher"].(string) if !ok { - return nil, fmt.Errorf("unable to cast cypher parameter %s", paramsMap["cypher"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast cypher parameter %s", paramsMap["cypher"]), nil) } if cypherStr == "" { - return nil, fmt.Errorf("parameter 'cypher' must be a non-empty string") + return nil, util.NewAgentError("parameter 'cypher' must be a non-empty string", nil) } dryRun, ok := paramsMap["dry_run"].(bool) if !ok { - return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to cast dry_run parameter %s", paramsMap["dry_run"]), nil) } - return source.RunQuery(ctx, cypherStr, nil, t.ReadOnly, dryRun) + resp, err := source.RunQuery(ctx, cypherStr, nil, t.ReadOnly, dryRun) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/neo4j/neo4jschema/neo4jschema.go b/internal/tools/neo4j/neo4jschema/neo4jschema.go index 441a8eaa72..9f217a2502 100644 --- a/internal/tools/neo4j/neo4jschema/neo4jschema.go +++ b/internal/tools/neo4j/neo4jschema/neo4jschema.go @@ -17,6 +17,7 @@ package neo4jschema import ( "context" "fmt" + "net/http" "sync" "time" @@ -27,6 +28,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/cache" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/types" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/neo4j/neo4j-go-driver/v5/neo4j" ) @@ -113,10 +115,10 @@ type Tool struct { // Invoke executes the tool's main logic: fetching the Neo4j schema. // It first checks the cache for a valid schema before extracting it from the database. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } // Check if a valid schema is already in the cache. @@ -129,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // If not cached, extract the schema from the database. schema, err := t.extractSchema(ctx, source) if err != nil { - return nil, fmt.Errorf("failed to extract database schema: %w", err) + return nil, util.ProcessGeneralError(err) } // Cache the newly extracted schema for future use. @@ -372,14 +374,14 @@ func (t Tool) GetAPOCSchema(ctx context.Context, source compatibleSource) ([]typ name: "apoc-relationships", fn: func(session neo4j.SessionWithContext) error { query := ` - MATCH (startNode)-[rel]->(endNode) - WITH - labels(startNode)[0] AS startNode, - type(rel) AS relType, - apoc.meta.cypher.types(rel) AS relProperties, - labels(endNode)[0] AS endNode, - count(*) AS count - RETURN relType, startNode, endNode, relProperties, count` + MATCH (startNode)-[rel]->(endNode) + WITH + labels(startNode)[0] AS startNode, + type(rel) AS relType, + apoc.meta.cypher.types(rel) AS relProperties, + labels(endNode)[0] AS endNode, + count(*) AS count + RETURN relType, startNode, endNode, relProperties, count` result, err := session.Run(ctx, query, nil) if err != nil { return fmt.Errorf("failed to extract relationships: %w", err) @@ -520,10 +522,10 @@ func (t Tool) GetSchemaWithoutAPOC(ctx context.Context, source compatibleSource, name: "relationship-schema", fn: func(session neo4j.SessionWithContext) error { relQuery := ` - MATCH (start)-[r]->(end) - WITH type(r) AS relType, labels(start) AS startLabels, labels(end) AS endLabels, count(*) AS count - RETURN relType, CASE WHEN size(startLabels) > 0 THEN startLabels[0] ELSE null END AS startLabel, CASE WHEN size(endLabels) > 0 THEN endLabels[0] ELSE null END AS endLabel, sum(count) AS totalCount - ORDER BY totalCount DESC` + MATCH (start)-[r]->(end) + WITH type(r) AS relType, labels(start) AS startLabels, labels(end) AS endLabels, count(*) AS count + RETURN relType, CASE WHEN size(startLabels) > 0 THEN startLabels[0] ELSE null END AS startLabel, CASE WHEN size(endLabels) > 0 THEN endLabels[0] ELSE null END AS endLabel, sum(count) AS totalCount + ORDER BY totalCount DESC` relResult, err := session.Run(ctx, relQuery, nil) if err != nil { return fmt.Errorf("relationship count query failed: %w", err) diff --git a/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go b/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go index 1987f24d45..173199daea 100644 --- a/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go +++ b/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -89,18 +91,22 @@ type Tool struct { } // Invoke executes the SQL statement provided in the parameters. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } sliceParams := params.AsSlice() sqlStr, ok := sliceParams[0].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", sliceParams[0]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", sliceParams[0]), nil) } - return source.RunSQL(ctx, sqlStr, nil) + resp, err := source.RunSQL(ctx, sqlStr, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go b/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go index 0b8a7421d3..ddcc83fbc5 100644 --- a/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go +++ b/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go @@ -18,11 +18,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -94,24 +96,28 @@ type Tool struct { } // Invoke executes the SQL statement with the provided parameters. -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + resp, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go b/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go index c91b3bcc06..1f7a047681 100644 --- a/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go +++ b/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go @@ -6,6 +6,7 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -77,25 +78,29 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() sqlParam, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql"]), nil) } // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, "executing `%s` tool query: %s", resourceType, sqlParam) - return source.RunSQL(ctx, sqlParam, nil) + resp, err := source.RunSQL(ctx, sqlParam, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/oracle/oraclesql/oraclesql.go b/internal/tools/oracle/oraclesql/oraclesql.go index 347b18d41b..84041ce6b1 100644 --- a/internal/tools/oracle/oraclesql/oraclesql.go +++ b/internal/tools/oracle/oraclesql/oraclesql.go @@ -6,11 +6,13 @@ import ( "context" "database/sql" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -81,21 +83,21 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() @@ -103,7 +105,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para fmt.Printf("[%d]=%T ", i, p) } fmt.Printf("\n") - return source.RunSQL(ctx, newStatement, sliceParams) + resp, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go b/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go index 4d48cbc6cb..e621a142ff 100644 --- a/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go +++ b/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go @@ -17,11 +17,13 @@ package postgresdatabaseoverview import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,13 +31,13 @@ import ( const resourceType string = "postgres-database-overview" const databaseOverviewStatement = ` - SELECT + SELECT current_setting('server_version') AS pg_version, pg_is_in_recovery() AS is_replica, (now() - pg_postmaster_start_time())::TEXT AS uptime, current_setting('max_connections')::int AS max_connections, - (SELECT count(*) FROM pg_stat_activity) AS current_connections, - (SELECT count(*) FROM pg_stat_activity WHERE state = 'active') AS active_connections, + (SELECT count(*) FROM pg_stat_activity) AS current_connections, + (SELECT count(*) FROM pg_stat_activity WHERE state = 'active') AS active_connections, round( (100.0 * (SELECT count(*) FROM pg_stat_activity) / current_setting('max_connections')::int), 2 @@ -57,7 +59,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - PostgresPool() *pgxpool.Pool // keep this so that sources are postgres compatible + PostgresPool() *pgxpool.Pool RunSQL(context.Context, string, []any) (any, error) } @@ -69,7 +71,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -83,7 +84,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -96,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -110,20 +109,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, databaseOverviewStatement, sliceParams) + resp, err := source.RunSQL(ctx, databaseOverviewStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go b/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go index 57e0c8fce4..7b81f9bfce 100644 --- a/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go +++ b/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go @@ -17,6 +17,7 @@ package postgresexecutesql import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -56,7 +57,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -69,7 +69,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) - // finish tool setup t := Tool{ Config: cfg, Parameters: params, @@ -79,7 +78,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -89,25 +87,28 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + return nil, util.NewAgentError(fmt.Sprintf("unable to get cast %s", paramsMap["sql"]), nil) } - // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { - return nil, fmt.Errorf("error getting logger: %s", err) + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql)) - return source.RunSQL(ctx, sql, nil) + resp, err := source.RunSQL(ctx, sql, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go b/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go index 81cc92673e..b4358f439f 100644 --- a/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go +++ b/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go @@ -17,11 +17,13 @@ package postgresgetcolumncardinality import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-get-column-cardinality" const getColumnCardinality = ` - SELECT + SELECT s.attname AS column_name, ROUND( CASE @@ -74,7 +76,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -95,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -108,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -122,20 +121,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, getColumnCardinality, sliceParams) + resp, err := source.RunSQL(ctx, getColumnCardinality, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go b/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go index a7b1f7587d..ab4b36c3a3 100644 --- a/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go +++ b/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go @@ -17,11 +17,13 @@ package postgreslistactivequeries import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,26 +31,26 @@ import ( const resourceType string = "postgres-list-active-queries" const listActiveQueriesStatement = ` - SELECT - pid, - usename AS user, - datname, - application_name, - client_addr, - state, - wait_event_type, - wait_event, - backend_start, - xact_start, - query_start, - now() - query_start AS query_duration, - query - FROM pg_stat_activity - WHERE state = 'active' - AND ($1::INTERVAL IS NULL OR now() - query_start >= $1::INTERVAL) - AND ($2::text IS NULL OR application_name NOT IN (SELECT trim(app) FROM unnest(string_to_array($2, ',')) AS app)) - ORDER BY query_duration DESC - LIMIT COALESCE($3::int, 50); + SELECT + pid, + usename AS user, + datname, + application_name, + client_addr, + state, + wait_event_type, + wait_event, + backend_start, + xact_start, + query_start, + now() - query_start AS query_duration, + query + FROM pg_stat_activity + WHERE state = 'active' + AND ($1::INTERVAL IS NULL OR now() - query_start >= $1::INTERVAL) + AND ($2::text IS NULL OR application_name NOT IN (SELECT trim(app) FROM unnest(string_to_array($2, ',')) AS app)) + ORDER BY query_duration DESC + LIMIT COALESCE($3::int, 50); ` func init() { @@ -78,7 +80,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -94,8 +95,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) paramManifest := allParameters.Manifest() mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup - t := Tool{ + return Tool{ Config: cfg, allParams: allParameters, manifest: tools.Manifest{ @@ -104,11 +104,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) AuthRequired: cfg.AuthRequired, }, mcpManifest: mcpManifest, - } - return t, nil + }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -118,21 +116,25 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listActiveQueriesStatement, sliceParams) + resp, err := source.RunSQL(ctx, listActiveQueriesStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go b/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go index 489df27583..6ecf06509d 100644 --- a/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go +++ b/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go @@ -17,11 +17,13 @@ package postgreslistavailableextensions import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,13 +31,13 @@ import ( const resourceType string = "postgres-list-available-extensions" const listAvailableExtensionsQuery = ` - SELECT - name, - default_version, - comment as description - FROM - pg_available_extensions - ORDER BY name; + SELECT + name, + default_version, + comment as description + FROM + pg_available_extensions + ORDER BY name; ` func init() { @@ -65,7 +67,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -76,7 +77,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) - // finish tool setup t := Tool{ Config: cfg, manifest: tools.Manifest{ @@ -90,7 +90,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -100,12 +99,16 @@ type Tool struct { Parameters parameters.Parameters } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } - return source.RunSQL(ctx, listAvailableExtensionsQuery, nil) + resp, err := source.RunSQL(ctx, listAvailableExtensionsQuery, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go b/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go index c78dd297d7..f01fc002a6 100644 --- a/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go +++ b/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go @@ -17,82 +17,83 @@ package postgreslistdatabasestats import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) const resourceType string = "postgres-list-database-stats" -// SQL query to list database statistics const listDatabaseStats = ` - WITH database_stats AS ( - SELECT - s.datname AS database_name, - -- Database Metadata - d.datallowconn AS is_connectable, - pg_get_userbyid(d.datdba) AS database_owner, - ts.spcname AS default_tablespace, + WITH database_stats AS ( + SELECT + s.datname AS database_name, + -- Database Metadata + d.datallowconn AS is_connectable, + pg_get_userbyid(d.datdba) AS database_owner, + ts.spcname AS default_tablespace, - -- Cache Performance - CASE - WHEN (s.blks_hit + s.blks_read) = 0 THEN 0 - ELSE round((s.blks_hit * 100.0) / (s.blks_hit + s.blks_read), 2) - END AS cache_hit_ratio_percent, - s.blks_read AS blocks_read_from_disk, - s.blks_hit AS blocks_hit_in_cache, + -- Cache Performance + CASE + WHEN (s.blks_hit + s.blks_read) = 0 THEN 0 + ELSE round((s.blks_hit * 100.0) / (s.blks_hit + s.blks_read), 2) + END AS cache_hit_ratio_percent, + s.blks_read AS blocks_read_from_disk, + s.blks_hit AS blocks_hit_in_cache, - -- Transaction Throughput - s.xact_commit, - s.xact_rollback, - round(s.xact_rollback * 100.0 / (s.xact_commit + s.xact_rollback + 1), 2) AS rollback_ratio_percent, + -- Transaction Throughput + s.xact_commit, + s.xact_rollback, + round(s.xact_rollback * 100.0 / (s.xact_commit + s.xact_rollback + 1), 2) AS rollback_ratio_percent, - -- Tuple Activity - s.tup_returned AS rows_returned_by_queries, - s.tup_fetched AS rows_fetched_by_scans, - s.tup_inserted, - s.tup_updated, - s.tup_deleted, + -- Tuple Activity + s.tup_returned AS rows_returned_by_queries, + s.tup_fetched AS rows_fetched_by_scans, + s.tup_inserted, + s.tup_updated, + s.tup_deleted, - -- Temporary File Usage - s.temp_files, - s.temp_bytes AS temp_size_bytes, + -- Temporary File Usage + s.temp_files, + s.temp_bytes AS temp_size_bytes, - -- Conflicts & Deadlocks - s.conflicts, - s.deadlocks, + -- Conflicts & Deadlocks + s.conflicts, + s.deadlocks, - -- General Info - s.numbackends AS active_connections, - s.stats_reset AS statistics_last_reset, - pg_database_size(s.datid) AS database_size_bytes - FROM - pg_stat_database s - JOIN - pg_database d ON d.oid = s.datid - JOIN - pg_tablespace ts ON ts.oid = d.dattablespace - WHERE - -- Exclude cloudsql internal databases - s.datname NOT IN ('cloudsqladmin') - -- Exclude template databases if not requested - AND ( $2::boolean IS TRUE OR d.datistemplate IS FALSE ) - ) - SELECT * - FROM database_stats - WHERE - ($1::text IS NULL OR database_name LIKE '%' || $1::text || '%') - AND ($3::text IS NULL OR database_owner LIKE '%' || $3::text || '%') - AND ($4::text IS NULL OR default_tablespace LIKE '%' || $4::text || '%') - ORDER BY - CASE WHEN $5::text = 'size' THEN database_size_bytes END DESC, - CASE WHEN $5::text = 'commit' THEN xact_commit END DESC, - database_name - LIMIT COALESCE($6::int, 10); + -- General Info + s.numbackends AS active_connections, + s.stats_reset AS statistics_last_reset, + pg_database_size(s.datid) AS database_size_bytes + FROM + pg_stat_database s + JOIN + pg_database d ON d.oid = s.datid + JOIN + pg_tablespace ts ON ts.oid = d.dattablespace + WHERE + -- Exclude cloudsql internal databases + s.datname NOT IN ('cloudsqladmin') + -- Exclude template databases if not requested + AND ( $2::boolean IS TRUE OR d.datistemplate IS FALSE ) + ) + SELECT * + FROM database_stats + WHERE + ($1::text IS NULL OR database_name LIKE '%' || $1::text || '%') + AND ($3::text IS NULL OR database_owner LIKE '%' || $3::text || '%') + AND ($4::text IS NULL OR default_tablespace LIKE '%' || $4::text || '%') + ORDER BY + CASE WHEN $5::text = 'size' THEN database_size_bytes END DESC, + CASE WHEN $5::text = 'commit' THEN xact_commit END DESC, + database_name + LIMIT COALESCE($6::int, 10); ` func init() { @@ -122,7 +123,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -164,7 +164,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -177,7 +176,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -187,21 +185,25 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listDatabaseStats, sliceParams) + resp, err := source.RunSQL(ctx, listDatabaseStats, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go b/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go index debd2d8036..10f8b92327 100644 --- a/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go +++ b/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go @@ -17,11 +17,13 @@ package postgreslistindexes import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,49 +31,49 @@ import ( const resourceType string = "postgres-list-indexes" const listIndexesStatement = ` - WITH IndexDetails AS ( - SELECT - s.schemaname AS schema_name, - t.relname AS table_name, - i.relname AS index_name, - am.amname AS index_type, - ix.indisunique AS is_unique, - ix.indisprimary AS is_primary, - pg_get_indexdef(i.oid) AS index_definition, - pg_relation_size(i.oid) AS index_size_bytes, - s.idx_scan AS index_scans, - s.idx_tup_read AS tuples_read, - s.idx_tup_fetch AS tuples_fetched, - CASE - WHEN s.idx_scan > 0 THEN true - ELSE false - END AS is_used - FROM pg_catalog.pg_class t - JOIN pg_catalog.pg_index ix - ON t.oid = ix.indrelid - JOIN pg_catalog.pg_class i - ON i.oid = ix.indexrelid - JOIN pg_catalog.pg_am am - ON i.relam = am.oid - JOIN pg_catalog.pg_stat_all_indexes s - ON i.oid = s.indexrelid - WHERE - t.relkind = 'r' - AND s.schemaname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') - AND s.schemaname NOT LIKE 'pg_temp_%' - ) - SELECT * - FROM IndexDetails - WHERE - ($1::text IS NULL OR schema_name LIKE '%' || $1 || '%') - AND ($2::text IS NULL OR table_name LIKE '%' || $2 || '%') - AND ($3::text IS NULL OR index_name LIKE '%' || $3 || '%') - AND ($4::boolean IS NOT TRUE OR is_used IS FALSE) - ORDER BY - schema_name, - table_name, - index_name - LIMIT COALESCE($5::int, 50); + WITH IndexDetails AS ( + SELECT + s.schemaname AS schema_name, + t.relname AS table_name, + i.relname AS index_name, + am.amname AS index_type, + ix.indisunique AS is_unique, + ix.indisprimary AS is_primary, + pg_get_indexdef(i.oid) AS index_definition, + pg_relation_size(i.oid) AS index_size_bytes, + s.idx_scan AS index_scans, + s.idx_tup_read AS tuples_read, + s.idx_tup_fetch AS tuples_fetched, + CASE + WHEN s.idx_scan > 0 THEN true + ELSE false + END AS is_used + FROM pg_catalog.pg_class t + JOIN pg_catalog.pg_index ix + ON t.oid = ix.indrelid + JOIN pg_catalog.pg_class i + ON i.oid = ix.indexrelid + JOIN pg_catalog.pg_am am + ON i.relam = am.oid + JOIN pg_catalog.pg_stat_all_indexes s + ON i.oid = s.indexrelid + WHERE + t.relkind = 'r' + AND s.schemaname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') + AND s.schemaname NOT LIKE 'pg_temp_%' + ) + SELECT * + FROM IndexDetails + WHERE + ($1::text IS NULL OR schema_name LIKE '%' || $1 || '%') + AND ($2::text IS NULL OR table_name LIKE '%' || $2 || '%') + AND ($3::text IS NULL OR index_name LIKE '%' || $3 || '%') + AND ($4::boolean IS NOT TRUE OR is_used IS FALSE) + ORDER BY + schema_name, + table_name, + index_name + LIMIT COALESCE($5::int, 50); ` func init() { @@ -101,7 +103,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -122,7 +123,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -135,7 +135,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -149,21 +148,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listIndexesStatement, sliceParams) + resp, err := source.RunSQL(ctx, listIndexesStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go b/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go index 8273ae9247..cdac40ab0e 100644 --- a/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go +++ b/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go @@ -17,11 +17,13 @@ package postgreslistinstalledextensions import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,24 +31,24 @@ import ( const resourceType string = "postgres-list-installed-extensions" const listAvailableExtensionsQuery = ` - SELECT - e.extname AS name, - e.extversion AS version, - n.nspname AS schema, - pg_get_userbyid(e.extowner) AS owner, - c.description AS description - FROM - pg_catalog.pg_extension e - LEFT JOIN - pg_catalog.pg_namespace n - ON - n.oid = e.extnamespace - LEFT JOIN - pg_catalog.pg_description c - ON - c.objoid = e.oid - AND c.classoid = 'pg_catalog.pg_extension'::pg_catalog.regclass - ORDER BY 1; + SELECT + e.extname AS name, + e.extversion AS version, + n.nspname AS schema, + pg_get_userbyid(e.extowner) AS owner, + c.description AS description + FROM + pg_catalog.pg_extension e + LEFT JOIN + pg_catalog.pg_namespace n + ON + n.oid = e.extnamespace + LEFT JOIN + pg_catalog.pg_description c + ON + c.objoid = e.oid + AND c.classoid = 'pg_catalog.pg_extension'::pg_catalog.regclass + ORDER BY 1; ` func init() { @@ -76,7 +78,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -87,7 +88,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) - // finish tool setup t := Tool{ Config: cfg, manifest: tools.Manifest{ @@ -100,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -109,12 +108,16 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } - return source.RunSQL(ctx, listAvailableExtensionsQuery, nil) + resp, err := source.RunSQL(ctx, listAvailableExtensionsQuery, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { @@ -145,7 +148,6 @@ func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, return "Authorization", nil } -// This tool does not have parameters, so return an empty set. func (t Tool) GetParameters() parameters.Parameters { return parameters.Parameters{} } diff --git a/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go b/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go index b29cb1e57e..bac4b6a01b 100644 --- a/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go +++ b/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go @@ -17,11 +17,13 @@ package postgreslistlocks import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-list-locks" const listLocks = ` - SELECT + SELECT locked.pid, locked.usename, locked.query, @@ -76,7 +78,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -93,7 +94,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -106,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -120,21 +119,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listLocks, sliceParams) + resp, err := source.RunSQL(ctx, listLocks, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go b/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go index 6d10837830..85d9dd0e35 100644 --- a/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go +++ b/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go @@ -17,11 +17,13 @@ package postgreslistpgsettings import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-list-pg-settings" const listPgSettingsStatement = ` - SELECT + SELECT name, setting AS current_value, unit, @@ -41,10 +43,10 @@ const listPgSettingsStatement = ` ELSE 'No' END AS requires_restart - FROM pg_settings - WHERE ($1::text IS NULL OR name LIKE '%' || $1::text || '%') - ORDER BY name - LIMIT COALESCE($2::int, 50); + FROM pg_settings + WHERE ($1::text IS NULL OR name LIKE '%' || $1::text || '%') + ORDER BY name + LIMIT COALESCE($2::int, 50); ` func init() { @@ -74,7 +76,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -92,7 +93,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -105,7 +105,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -115,19 +114,23 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listPgSettingsStatement, sliceParams) + resp, err := source.RunSQL(ctx, listPgSettingsStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go b/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go index db6af2c62f..a5ee63db16 100644 --- a/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go +++ b/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go @@ -17,11 +17,13 @@ package postgreslistpublicationtables import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,33 +31,33 @@ import ( const resourceType string = "postgres-list-publication-tables" const listPublicationTablesStatement = ` - WITH - publication_details AS ( - SELECT - pt.pubname AS publication_name, - pt.schemaname AS schema_name, - pt.tablename AS table_name, - -- Definition details - p.puballtables AS publishes_all_tables, - p.pubinsert AS publishes_inserts, - p.pubupdate AS publishes_updates, - p.pubdelete AS publishes_deletes, - p.pubtruncate AS publishes_truncates, - -- Owner information - pg_catalog.pg_get_userbyid(p.pubowner) AS publication_owner - FROM pg_catalog.pg_publication_tables pt - JOIN pg_catalog.pg_publication p - ON pt.pubname = p.pubname - ) - SELECT * - FROM publication_details - WHERE - (NULLIF(TRIM($1::text), '') IS NULL OR table_name = ANY(regexp_split_to_array(TRIM($1::text), '\s*,\s*'))) - AND (NULLIF(TRIM($2::text), '') IS NULL OR publication_name = ANY(regexp_split_to_array(TRIM($2::text), '\s*,\s*'))) - AND (NULLIF(TRIM($3::text), '') IS NULL OR schema_name = ANY(regexp_split_to_array(TRIM($3::text), '\s*,\s*'))) - ORDER BY - publication_name, schema_name, table_name - LIMIT COALESCE($4::int, 50); + WITH + publication_details AS ( + SELECT + pt.pubname AS publication_name, + pt.schemaname AS schema_name, + pt.tablename AS table_name, + -- Definition details + p.puballtables AS publishes_all_tables, + p.pubinsert AS publishes_inserts, + p.pubupdate AS publishes_updates, + p.pubdelete AS publishes_deletes, + p.pubtruncate AS publishes_truncates, + -- Owner information + pg_catalog.pg_get_userbyid(p.pubowner) AS publication_owner + FROM pg_catalog.pg_publication_tables pt + JOIN pg_catalog.pg_publication p + ON pt.pubname = p.pubname + ) + SELECT * + FROM publication_details + WHERE + (NULLIF(TRIM($1::text), '') IS NULL OR table_name = ANY(regexp_split_to_array(TRIM($1::text), '\s*,\s*'))) + AND (NULLIF(TRIM($2::text), '') IS NULL OR publication_name = ANY(regexp_split_to_array(TRIM($2::text), '\s*,\s*'))) + AND (NULLIF(TRIM($3::text), '') IS NULL OR schema_name = ANY(regexp_split_to_array(TRIM($3::text), '\s*,\s*'))) + ORDER BY + publication_name, schema_name, table_name + LIMIT COALESCE($4::int, 50); ` func init() { @@ -85,7 +87,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -105,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -118,7 +118,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -128,20 +127,24 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listPublicationTablesStatement, sliceParams) + resp, err := source.RunSQL(ctx, listPublicationTablesStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go b/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go index 20303abfc3..f54c3dc554 100644 --- a/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go +++ b/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go @@ -17,11 +17,13 @@ package postgreslistquerystats import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-list-query-stats" const listQueryStats = ` - SELECT + SELECT d.datname, s.query, s.calls, @@ -75,7 +77,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -95,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -108,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -122,19 +121,23 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listQueryStats, sliceParams) + resp, err := source.RunSQL(ctx, listQueryStats, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistroles/postgreslistroles.go b/internal/tools/postgres/postgreslistroles/postgreslistroles.go index c14b652c58..20cf87c20f 100644 --- a/internal/tools/postgres/postgreslistroles/postgreslistroles.go +++ b/internal/tools/postgres/postgreslistroles/postgreslistroles.go @@ -17,11 +17,13 @@ package postgreslistroles import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,45 +31,45 @@ import ( const resourceType string = "postgres-list-roles" const listRolesStatement = ` - WITH RoleDetails AS ( - SELECT - r.rolname AS role_name, - r.oid AS oid, - r.rolconnlimit AS connection_limit, - r.rolsuper AS is_superuser, - r.rolinherit AS inherits_privileges, - r.rolcreaterole AS can_create_roles, - r.rolcreatedb AS can_create_db, - r.rolcanlogin AS can_login, - r.rolreplication AS is_replication_role, - r.rolbypassrls AS bypass_rls, - r.rolvaliduntil AS valid_until, - -- List of roles that belong to this role (Direct Members) - ARRAY( - SELECT m_r.rolname - FROM pg_auth_members pam - JOIN pg_roles m_r ON pam.member = m_r.oid - WHERE pam.roleid = r.oid - ) AS direct_members, - -- List of roles that this role belongs to (Member Of) - ARRAY( - SELECT g_r.rolname - FROM pg_auth_members pam - JOIN pg_roles g_r ON pam.roleid = g_r.oid - WHERE pam.member = r.oid - ) AS member_of - FROM pg_roles r - -- Exclude system and internal roles - WHERE r.rolname NOT LIKE 'cloudsql%' - AND r.rolname NOT LIKE 'alloydb_%' - AND r.rolname NOT LIKE 'pg_%' - ) - SELECT * - FROM RoleDetails - WHERE - ($1::text IS NULL OR role_name LIKE '%' || $1 || '%') - ORDER BY role_name - LIMIT COALESCE($2::int, 50); + WITH RoleDetails AS ( + SELECT + r.rolname AS role_name, + r.oid AS oid, + r.rolconnlimit AS connection_limit, + r.rolsuper AS is_superuser, + r.rolinherit AS inherits_privileges, + r.rolcreaterole AS can_create_roles, + r.rolcreatedb AS can_create_db, + r.rolcanlogin AS can_login, + r.rolreplication AS is_replication_role, + r.rolbypassrls AS bypass_rls, + r.rolvaliduntil AS valid_until, + -- List of roles that belong to this role (Direct Members) + ARRAY( + SELECT m_r.rolname + FROM pg_auth_members pam + JOIN pg_roles m_r ON pam.member = m_r.oid + WHERE pam.roleid = r.oid + ) AS direct_members, + -- List of roles that this role belongs to (Member Of) + ARRAY( + SELECT g_r.rolname + FROM pg_auth_members pam + JOIN pg_roles g_r ON pam.roleid = g_r.oid + WHERE pam.member = r.oid + ) AS member_of + FROM pg_roles r + -- Exclude system and internal roles + WHERE r.rolname NOT LIKE 'cloudsql%' + AND r.rolname NOT LIKE 'alloydb_%' + AND r.rolname NOT LIKE 'pg_%' + ) + SELECT * + FROM RoleDetails + WHERE + ($1::text IS NULL OR role_name LIKE '%' || $1 || '%') + ORDER BY role_name + LIMIT COALESCE($2::int, 50); ` func init() { @@ -97,7 +99,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -116,7 +117,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -129,7 +129,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -143,20 +142,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listRolesStatement, sliceParams) + resp, err := source.RunSQL(ctx, listRolesStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go b/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go index dbf2a8b367..b1ff208f08 100644 --- a/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go +++ b/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go @@ -17,11 +17,13 @@ package postgreslistschemas import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-list-schemas" const listSchemasStatement = ` - WITH + WITH schema_grants AS ( SELECT schema_oid, jsonb_object_agg(grantee, privileges) AS grants FROM @@ -52,27 +54,27 @@ const listSchemasStatement = ` SELECT n.nspname AS schema_name, pg_catalog.pg_get_userbyid(n.nspowner) AS owner, - COALESCE(sg.grants, '{}'::jsonb) AS grants, - ( - SELECT COUNT(*) - FROM pg_catalog.pg_class c - WHERE c.relnamespace = n.oid AND c.relkind = 'r' - ) AS tables, - ( - SELECT COUNT(*) - FROM pg_catalog.pg_class c - WHERE c.relnamespace = n.oid AND c.relkind = 'v' - ) AS views, - (SELECT COUNT(*) FROM pg_catalog.pg_proc p WHERE p.pronamespace = n.oid) - AS functions + COALESCE(sg.grants, '{}'::jsonb) AS grants, + ( + SELECT COUNT(*) + FROM pg_catalog.pg_class c + WHERE c.relnamespace = n.oid AND c.relkind = 'r' + ) AS tables, + ( + SELECT COUNT(*) + FROM pg_catalog.pg_class c + WHERE c.relnamespace = n.oid AND c.relkind = 'v' + ) AS views, + (SELECT COUNT(*) FROM pg_catalog.pg_proc p WHERE p.pronamespace = n.oid) + AS functions FROM pg_catalog.pg_namespace n LEFT JOIN schema_grants sg ON n.oid = sg.schema_oid ) - SELECT * - FROM all_schemas - -- Exclude system schemas and temporary schemas created per session. - WHERE + SELECT * + FROM all_schemas + -- Exclude system and temporary schemas created per session. + WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') AND schema_name NOT LIKE 'pg_temp_%' AND schema_name NOT LIKE 'pg_toast_temp_%' @@ -109,7 +111,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -128,7 +129,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -141,7 +141,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -151,20 +150,24 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listSchemasStatement, sliceParams) + resp, err := source.RunSQL(ctx, listSchemasStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go b/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go index bee44edbca..aca352317c 100644 --- a/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go +++ b/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go @@ -17,11 +17,13 @@ package postgreslistsequences import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,22 +31,22 @@ import ( const resourceType string = "postgres-list-sequences" const listSequencesStatement = ` - SELECT - sequencename as sequence_name, - schemaname as schema_name, - sequenceowner as sequence_owner, - data_type, - start_value, - min_value, - max_value, - increment_by, - last_value - FROM pg_sequences - WHERE - ($1::text IS NULL OR schemaname LIKE '%' || $1 || '%') - AND ($2::text IS NULL OR sequencename LIKE '%' || $2 || '%') - ORDER BY schema_name, sequence_name - LIMIT COALESCE($3::int, 50); + SELECT + sequencename as sequence_name, + schemaname as schema_name, + sequenceowner as sequence_owner, + data_type, + start_value, + min_value, + max_value, + increment_by, + last_value + FROM pg_sequences + WHERE + ($1::text IS NULL OR schemaname LIKE '%' || $1 || '%') + AND ($2::text IS NULL OR sequencename LIKE '%' || $2 || '%') + ORDER BY schema_name, sequence_name + LIMIT COALESCE($3::int, 50); ` @@ -75,7 +77,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -94,7 +95,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -107,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -121,21 +120,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listSequencesStatement, sliceParams) + resp, err := source.RunSQL(ctx, listSequencesStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresliststoredprocedure/postgresliststoredprocedure.go b/internal/tools/postgres/postgresliststoredprocedure/postgresliststoredprocedure.go index f8d9891cac..96c727a020 100644 --- a/internal/tools/postgres/postgresliststoredprocedure/postgresliststoredprocedure.go +++ b/internal/tools/postgres/postgresliststoredprocedure/postgresliststoredprocedure.go @@ -17,6 +17,7 @@ package postgresliststoredprocedure import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -25,6 +26,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -32,7 +34,7 @@ import ( const resourceType string = "postgres-list-stored-procedure" const listStoredProcedure = ` - SELECT + SELECT n.nspname AS schema_name, p.proname AS name, r.rolname AS owner, @@ -85,7 +87,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -118,7 +119,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -132,7 +132,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -147,18 +146,18 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() results, err := t.pool.Query(ctx, listStoredProcedure, sliceParams...) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, util.ProcessGeneralError(err) } defer results.Close() @@ -168,7 +167,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para for results.Next() { values, err := results.Values() if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) + return nil, util.NewClientServerError("unable to parse row", http.StatusInternalServerError, err) } rowMap := make(map[string]any) for i, field := range fields { diff --git a/internal/tools/postgres/postgreslisttables/postgreslisttables.go b/internal/tools/postgres/postgreslisttables/postgreslisttables.go index da3ea82af0..70a4b594e9 100644 --- a/internal/tools/postgres/postgreslisttables/postgreslisttables.go +++ b/internal/tools/postgres/postgreslisttables/postgreslisttables.go @@ -17,11 +17,13 @@ package postgreslisttables import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,81 +31,81 @@ import ( const resourceType string = "postgres-list-tables" const listTablesStatement = ` - WITH desired_relkinds AS ( - SELECT ARRAY['r', 'p']::char[] AS kinds -- Always consider both 'TABLE' and 'PARTITIONED TABLE' - ), - table_info AS ( - SELECT - t.oid AS table_oid, - ns.nspname AS schema_name, - t.relname AS table_name, - pg_get_userbyid(t.relowner) AS table_owner, - obj_description(t.oid, 'pg_class') AS table_comment, - t.relkind AS object_kind - FROM - pg_class t - JOIN - pg_namespace ns ON ns.oid = t.relnamespace - CROSS JOIN desired_relkinds dk - WHERE - t.relkind = ANY(dk.kinds) -- Filter by selected table relkinds ('r', 'p') - AND (NULLIF(TRIM($1), '') IS NULL OR t.relname = ANY(string_to_array($1,','))) -- $1 is object_names - AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast','google_ml') - AND ns.nspname NOT LIKE 'pg_temp_%' AND ns.nspname NOT LIKE 'pg_toast_temp_%' - ), - columns_info AS ( - SELECT - att.attrelid AS table_oid, att.attname AS column_name, format_type(att.atttypid, att.atttypmod) AS data_type, - att.attnum AS column_ordinal_position, att.attnotnull AS is_not_nullable, - pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(att.attrelid, att.attnum) AS column_comment - FROM pg_attribute att LEFT JOIN pg_attrdef ad ON att.attrelid = ad.adrelid AND att.attnum = ad.adnum - JOIN table_info ti ON att.attrelid = ti.table_oid WHERE att.attnum > 0 AND NOT att.attisdropped - ), - constraints_info AS ( - SELECT - con.conrelid AS table_oid, con.conname AS constraint_name, pg_get_constraintdef(con.oid) AS constraint_definition, - CASE con.contype WHEN 'p' THEN 'PRIMARY KEY' WHEN 'f' THEN 'FOREIGN KEY' WHEN 'u' THEN 'UNIQUE' WHEN 'c' THEN 'CHECK' ELSE con.contype::text END AS constraint_type, - (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.conkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = u.attnum) AS constraint_columns, - NULLIF(con.confrelid, 0)::regclass AS foreign_key_referenced_table, - (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.confkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.confrelid AND att.attnum = u.attnum WHERE con.contype = 'f') AS foreign_key_referenced_columns - FROM pg_constraint con JOIN table_info ti ON con.conrelid = ti.table_oid - ), - indexes_info AS ( - SELECT - idx.indrelid AS table_oid, ic.relname AS index_name, pg_get_indexdef(idx.indexrelid) AS index_definition, - idx.indisunique AS is_unique, idx.indisprimary AS is_primary, am.amname AS index_method, - (SELECT array_agg(att.attname ORDER BY u.ord) FROM unnest(idx.indkey::int[]) WITH ORDINALITY AS u(colidx, ord) LEFT JOIN pg_attribute att ON att.attrelid = idx.indrelid AND att.attnum = u.colidx WHERE u.colidx <> 0) AS index_columns - FROM pg_index idx JOIN pg_class ic ON ic.oid = idx.indexrelid JOIN pg_am am ON am.oid = ic.relam JOIN table_info ti ON idx.indrelid = ti.table_oid - ), - triggers_info AS ( - SELECT tg.tgrelid AS table_oid, tg.tgname AS trigger_name, pg_get_triggerdef(tg.oid) AS trigger_definition, tg.tgenabled AS trigger_enabled_state - FROM pg_trigger tg JOIN table_info ti ON tg.tgrelid = ti.table_oid WHERE NOT tg.tgisinternal - ) - SELECT - ti.schema_name, - ti.table_name AS object_name, - CASE - WHEN $2 = 'simple' THEN - -- IF format is 'simple', return basic JSON - json_build_object('name', ti.table_name) - ELSE - json_build_object( - 'schema_name', ti.schema_name, - 'object_name', ti.table_name, - 'object_type', CASE ti.object_kind - WHEN 'r' THEN 'TABLE' - WHEN 'p' THEN 'PARTITIONED TABLE' - ELSE ti.object_kind::text -- Should not happen due to WHERE clause - END, - 'owner', ti.table_owner, - 'comment', ti.table_comment, - 'columns', COALESCE((SELECT json_agg(json_build_object('column_name',ci.column_name,'data_type',ci.data_type,'ordinal_position',ci.column_ordinal_position,'is_not_nullable',ci.is_not_nullable,'column_default',ci.column_default,'column_comment',ci.column_comment) ORDER BY ci.column_ordinal_position) FROM columns_info ci WHERE ci.table_oid = ti.table_oid), '[]'::json), - 'constraints', COALESCE((SELECT json_agg(json_build_object('constraint_name',cons.constraint_name,'constraint_type',cons.constraint_type,'constraint_definition',cons.constraint_definition,'constraint_columns',cons.constraint_columns,'foreign_key_referenced_table',cons.foreign_key_referenced_table,'foreign_key_referenced_columns',cons.foreign_key_referenced_columns)) FROM constraints_info cons WHERE cons.table_oid = ti.table_oid), '[]'::json), - 'indexes', COALESCE((SELECT json_agg(json_build_object('index_name',ii.index_name,'index_definition',ii.index_definition,'is_unique',ii.is_unique,'is_primary',ii.is_primary,'index_method',ii.index_method,'index_columns',ii.index_columns)) FROM indexes_info ii WHERE ii.table_oid = ti.table_oid), '[]'::json), - 'triggers', COALESCE((SELECT json_agg(json_build_object('trigger_name',tri.trigger_name,'trigger_definition',tri.trigger_definition,'trigger_enabled_state',tri.trigger_enabled_state)) FROM triggers_info tri WHERE tri.table_oid = ti.table_oid), '[]'::json) - ) - END AS object_details - FROM table_info ti ORDER BY ti.schema_name, ti.table_name; + WITH desired_relkinds AS ( + SELECT ARRAY['r', 'p']::char[] AS kinds -- Always consider both 'TABLE' and 'PARTITIONED TABLE' + ), + table_info AS ( + SELECT + t.oid AS table_oid, + ns.nspname AS schema_name, + t.relname AS table_name, + pg_get_userbyid(t.relowner) AS table_owner, + obj_description(t.oid, 'pg_class') AS table_comment, + t.relkind AS object_kind + FROM + pg_class t + JOIN + pg_namespace ns ON ns.oid = t.relnamespace + CROSS JOIN desired_relkinds dk + WHERE + t.relkind = ANY(dk.kinds) -- Filter by selected table relkinds ('r', 'p') + AND (NULLIF(TRIM($1), '') IS NULL OR t.relname = ANY(string_to_array($1,','))) -- $1 is object_names + AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast','google_ml') + AND ns.nspname NOT LIKE 'pg_temp_%' AND ns.nspname NOT LIKE 'pg_toast_temp_%' + ), + columns_info AS ( + SELECT + att.attrelid AS table_oid, att.attname AS column_name, format_type(att.atttypid, att.atttypmod) AS data_type, + att.attnum AS column_ordinal_position, att.attnotnull AS is_not_nullable, + pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(att.attrelid, att.attnum) AS column_comment + FROM pg_attribute att LEFT JOIN pg_attrdef ad ON att.attrelid = ad.adrelid AND att.attnum = ad.adnum + JOIN table_info ti ON att.attrelid = ti.table_oid WHERE att.attnum > 0 AND NOT att.attisdropped + ), + constraints_info AS ( + SELECT + con.conrelid AS table_oid, con.conname AS constraint_name, pg_get_constraintdef(con.oid) AS constraint_definition, + CASE con.contype WHEN 'p' THEN 'PRIMARY KEY' WHEN 'f' THEN 'FOREIGN KEY' WHEN 'u' THEN 'UNIQUE' WHEN 'c' THEN 'CHECK' ELSE con.contype::text END AS constraint_type, + (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.conkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = u.attnum) AS constraint_columns, + NULLIF(con.confrelid, 0)::regclass AS foreign_key_referenced_table, + (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.confkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.confrelid AND att.attnum = u.attnum WHERE con.contype = 'f') AS foreign_key_referenced_columns + FROM pg_constraint con JOIN table_info ti ON con.conrelid = ti.table_oid + ), + indexes_info AS ( + SELECT + idx.indrelid AS table_oid, ic.relname AS index_name, pg_get_indexdef(idx.indexrelid) AS index_definition, + idx.indisunique AS is_unique, idx.indisprimary AS is_primary, am.amname AS index_method, + (SELECT array_agg(att.attname ORDER BY u.ord) FROM unnest(idx.indkey::int[]) WITH ORDINALITY AS u(colidx, ord) LEFT JOIN pg_attribute att ON att.attrelid = idx.indrelid AND att.attnum = u.colidx WHERE u.colidx <> 0) AS index_columns + FROM pg_index idx JOIN pg_class ic ON ic.oid = idx.indexrelid JOIN pg_am am ON am.oid = ic.relam JOIN table_info ti ON idx.indrelid = ti.table_oid + ), + triggers_info AS ( + SELECT tg.tgrelid AS table_oid, tg.tgname AS trigger_name, pg_get_triggerdef(tg.oid) AS trigger_definition, tg.tgenabled AS trigger_enabled_state + FROM pg_trigger tg JOIN table_info ti ON tg.tgrelid = ti.table_oid WHERE NOT tg.tgisinternal + ) + SELECT + ti.schema_name, + ti.table_name AS object_name, + CASE + WHEN $2 = 'simple' THEN + -- IF format is 'simple', return basic JSON + json_build_object('name', ti.table_name) + ELSE + json_build_object( + 'schema_name', ti.schema_name, + 'object_name', ti.table_name, + 'object_type', CASE ti.object_kind + WHEN 'r' THEN 'TABLE' + WHEN 'p' THEN 'PARTITIONED TABLE' + ELSE ti.object_kind::text -- Should not happen due to WHERE clause + END, + 'owner', ti.table_owner, + 'comment', ti.table_comment, + 'columns', COALESCE((SELECT json_agg(json_build_object('column_name',ci.column_name,'data_type',ci.data_type,'ordinal_position',ci.column_ordinal_position,'is_not_nullable',ci.is_not_nullable,'column_default',ci.column_default,'column_comment',ci.column_comment) ORDER BY ci.column_ordinal_position) FROM columns_info ci WHERE ci.table_oid = ti.table_oid), '[]'::json), + 'constraints', COALESCE((SELECT json_agg(json_build_object('constraint_name',cons.constraint_name,'constraint_type',cons.constraint_type,'constraint_definition',cons.constraint_definition,'constraint_columns',cons.constraint_columns,'foreign_key_referenced_table',cons.foreign_key_referenced_table,'foreign_key_referenced_columns',cons.foreign_key_referenced_columns)) FROM constraints_info cons WHERE cons.table_oid = ti.table_oid), '[]'::json), + 'indexes', COALESCE((SELECT json_agg(json_build_object('index_name',ii.index_name,'index_definition',ii.index_definition,'is_unique',ii.is_unique,'is_primary',ii.is_primary,'index_method',ii.index_method,'index_columns',ii.index_columns)) FROM indexes_info ii WHERE ii.table_oid = ti.table_oid), '[]'::json), + 'triggers', COALESCE((SELECT json_agg(json_build_object('trigger_name',tri.trigger_name,'trigger_definition',tri.trigger_definition,'trigger_enabled_state',tri.trigger_enabled_state)) FROM triggers_info tri WHERE tri.table_oid = ti.table_oid), '[]'::json) + ) + END AS object_details + FROM table_info ti ORDER BY ti.schema_name, ti.table_name; ` func init() { @@ -133,7 +135,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -158,7 +159,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -168,31 +168,31 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() tableNames, ok := paramsMap["table_names"].(string) if !ok { - return nil, fmt.Errorf("invalid 'table_names' parameter; expected a string") + return nil, util.NewAgentError("invalid 'table_names' parameter; expected a string", nil) } outputFormat, _ := paramsMap["output_format"].(string) if outputFormat != "simple" && outputFormat != "detailed" { - return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) + return nil, util.NewAgentError(fmt.Sprintf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat), nil) } resp, err := source.RunSQL(ctx, listTablesStatement, []any{tableNames, outputFormat}) if err != nil { - return nil, err + return nil, util.ProcessGeneralError(err) } resSlice, ok := resp.([]any) if !ok || len(resSlice) == 0 { return []any{}, nil } - return resp, err + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go b/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go index 588caf8117..a5a3296dec 100644 --- a/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go +++ b/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go @@ -17,11 +17,13 @@ package postgreslisttablespaces import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,29 +31,29 @@ import ( const resourceType string = "postgres-list-tablespaces" const listTableSpacesStatement = ` - WITH - tablespace_info AS ( - SELECT - spcname AS tablespace_name, - pg_catalog.pg_get_userbyid(spcowner) AS owner_name, - CASE - WHEN pg_catalog.has_tablespace_privilege(oid, 'CREATE') THEN pg_tablespace_size(oid) - ELSE NULL - END AS size_in_bytes, - oid, - spcacl, - spcoptions - FROM - pg_tablespace - ) - SELECT * - FROM - tablespace_info - WHERE - ($1::text IS NULL OR tablespace_name LIKE '%' || $1::text || '%') - ORDER BY - tablespace_name - LIMIT COALESCE($2::int, 50); + WITH + tablespace_info AS ( + SELECT + spcname AS tablespace_name, + pg_catalog.pg_get_userbyid(spcowner) AS owner_name, + CASE + WHEN pg_catalog.has_tablespace_privilege(oid, 'CREATE') THEN pg_tablespace_size(oid) + ELSE NULL + END AS size_in_bytes, + oid, + spcacl, + spcoptions + FROM + pg_tablespace + ) + SELECT * + FROM + tablespace_info + WHERE + ($1::text IS NULL OR tablespace_name LIKE '%' || $1::text || '%') + ORDER BY + tablespace_name + LIMIT COALESCE($2::int, 50); ` func init() { @@ -81,7 +83,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -99,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -112,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -126,24 +125,28 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() tablespaceName, ok := paramsMap["tablespace_name"].(string) if !ok { - return nil, fmt.Errorf("invalid 'tablespace_name' parameter; expected a string") + return nil, util.NewAgentError("invalid 'tablespace_name' parameter; expected a string", nil) } limit, ok := paramsMap["limit"].(int) if !ok { - return nil, fmt.Errorf("invalid 'limit' parameter; expected an integer") + return nil, util.NewAgentError("invalid 'limit' parameter; expected an integer", nil) } - return source.RunSQL(ctx, listTableSpacesStatement, []any{tablespaceName, limit}) + resp, err := source.RunSQL(ctx, listTableSpacesStatement, []any{tablespaceName, limit}) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go b/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go index 8e5d8e3309..13c4a9b05c 100644 --- a/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go +++ b/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go @@ -17,11 +17,13 @@ package postgreslisttablestats import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-list-table-stats" const listTableStats = ` - WITH table_stats AS ( + WITH table_stats AS ( SELECT s.schemaname AS schema_name, s.relname AS table_name, @@ -102,7 +104,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -121,19 +122,18 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if cfg.Description == "" { cfg.Description = `Lists the user table statistics in the database ordered by number of - sequential scans with a default limit of 50 rows. Returns the following - columns: schema name, table name, table size in bytes, number of - sequential scans, number of index scans, idx_scan_ratio_percent (showing - the percentage of total scans that utilized an index, where a low ratio - indicates missing or ineffective indexes), number of live rows, number - of dead rows, dead_row_ratio_percent (indicating potential table bloat), - total number of rows inserted, updated, and deleted, the timestamps - for the last_vacuum, last_autovacuum, and last_autoanalyze operations.` + sequential scans with a default limit of 50 rows. Returns the following + columns: schema name, table name, table size in bytes, number of + sequential scans, number of index scans, idx_scan_ratio_percent (showing + the percentage of total scans that utilized an index, where a low ratio + indicates missing or ineffective indexes), number of live rows, number + of dead rows, dead_row_ratio_percent (indicating potential table bloat), + total number of rows inserted, updated, and deleted, the timestamps + for the last_vacuum, last_autovacuum, and last_autoanalyze operations.` } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -146,7 +146,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -160,21 +159,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listTableStats, sliceParams) + resp, err := source.RunSQL(ctx, listTableStats, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go b/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go index 810f242e62..63889bfb46 100644 --- a/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go +++ b/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go @@ -17,11 +17,13 @@ package postgreslisttriggers import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,49 +31,49 @@ import ( const resourceType string = "postgres-list-triggers" const listTriggersStatement = ` - WITH - trigger_list AS ( - SELECT - t.tgname AS trigger_name, - n.nspname AS schema_name, - c.relname AS table_name, - CASE t.tgenabled - WHEN 'O' THEN 'ENABLED' - WHEN 'D' THEN 'DISABLED' - WHEN 'R' THEN 'REPLICA' - WHEN 'A' THEN 'ALWAYS' - END AS status, - CASE - WHEN (t.tgtype::int & 2) = 2 THEN 'BEFORE' - WHEN (t.tgtype::int & 64) = 64 THEN 'INSTEAD OF' - ELSE 'AFTER' - END AS timing, - concat_ws( - ', ', - CASE WHEN (t.tgtype::int & 4) = 4 THEN 'INSERT' END, - CASE WHEN (t.tgtype::int & 16) = 16 THEN 'UPDATE' END, - CASE WHEN (t.tgtype::int & 8) = 8 THEN 'DELETE' END, - CASE WHEN (t.tgtype::int & 32) = 32 THEN 'TRUNCATE' END) AS events, - CASE WHEN (t.tgtype::int & 1) = 1 THEN 'ROW' ELSE 'STATEMENT' END AS activation_level, - p.proname AS function_name, - pg_get_triggerdef(t.oid) AS definition - FROM pg_trigger t - JOIN pg_class c - ON t.tgrelid = c.oid - JOIN pg_namespace n - ON c.relnamespace = n.oid - LEFT JOIN pg_proc p - ON t.tgfoid = p.oid - WHERE NOT t.tgisinternal - ) - SELECT * - FROM trigger_list - WHERE - ($1::text IS NULL OR trigger_name LIKE '%' || $1::text || '%') - AND ($2::text IS NULL OR schema_name LIKE '%' || $2::text || '%') - AND ($3::text IS NULL OR table_name LIKE '%' || $3::text || '%') - ORDER BY schema_name, table_name, trigger_name - LIMIT COALESCE($4::int, 50); + WITH + trigger_list AS ( + SELECT + t.tgname AS trigger_name, + n.nspname AS schema_name, + c.relname AS table_name, + CASE t.tgenabled + WHEN 'O' THEN 'ENABLED' + WHEN 'D' THEN 'DISABLED' + WHEN 'R' THEN 'REPLICA' + WHEN 'A' THEN 'ALWAYS' + END AS status, + CASE + WHEN (t.tgtype::int & 2) = 2 THEN 'BEFORE' + WHEN (t.tgtype::int & 64) = 64 THEN 'INSTEAD OF' + ELSE 'AFTER' + END AS timing, + concat_ws( + ', ', + CASE WHEN (t.tgtype::int & 4) = 4 THEN 'INSERT' END, + CASE WHEN (t.tgtype::int & 16) = 16 THEN 'UPDATE' END, + CASE WHEN (t.tgtype::int & 8) = 8 THEN 'DELETE' END, + CASE WHEN (t.tgtype::int & 32) = 32 THEN 'TRUNCATE' END) AS events, + CASE WHEN (t.tgtype::int & 1) = 1 THEN 'ROW' ELSE 'STATEMENT' END AS activation_level, + p.proname AS function_name, + pg_get_triggerdef(t.oid) AS definition + FROM pg_trigger t + JOIN pg_class c + ON t.tgrelid = c.oid + JOIN pg_namespace n + ON c.relnamespace = n.oid + LEFT JOIN pg_proc p + ON t.tgfoid = p.oid + WHERE NOT t.tgisinternal + ) + SELECT * + FROM trigger_list + WHERE + ($1::text IS NULL OR trigger_name LIKE '%' || $1::text || '%') + AND ($2::text IS NULL OR schema_name LIKE '%' || $2::text || '%') + AND ($3::text IS NULL OR table_name LIKE '%' || $3::text || '%') + ORDER BY schema_name, table_name, trigger_name + LIMIT COALESCE($4::int, 50); ` func init() { @@ -101,7 +103,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -121,7 +122,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -134,7 +134,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -148,20 +147,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listTriggersStatement, sliceParams) + resp, err := source.RunSQL(ctx, listTriggersStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistviews/postgreslistviews.go b/internal/tools/postgres/postgreslistviews/postgreslistviews.go index e2d49691fa..e4359b9759 100644 --- a/internal/tools/postgres/postgreslistviews/postgreslistviews.go +++ b/internal/tools/postgres/postgreslistviews/postgreslistviews.go @@ -17,11 +17,13 @@ package postgreslistviews import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,24 +31,24 @@ import ( const resourceType string = "postgres-list-views" const listViewsStatement = ` - WITH list_views AS ( - SELECT - schemaname AS schema_name, - viewname AS view_name, - viewowner AS owner_name, - definition - FROM pg_views - ) - SELECT * - FROM list_views - WHERE - schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') - AND schema_name NOT LIKE 'pg_temp_%' - AND ($1::text IS NULL OR view_name ILIKE '%' || $1::text || '%') - AND ($2::text IS NULL OR schema_name ILIKE '%' || $2::text || '%') - ORDER BY - schema_name, view_name - LIMIT COALESCE($3::int, 50); + WITH list_views AS ( + SELECT + schemaname AS schema_name, + viewname AS view_name, + viewowner AS owner_name, + definition + FROM pg_views + ) + SELECT * + FROM list_views + WHERE + schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') + AND schema_name NOT LIKE 'pg_temp_%' + AND ($1::text IS NULL OR view_name ILIKE '%' || $1::text || '%') + AND ($2::text IS NULL OR schema_name ILIKE '%' || $2::text || '%') + ORDER BY + schema_name, view_name + LIMIT COALESCE($3::int, 50); ` func init() { @@ -76,7 +78,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -95,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -108,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -118,20 +117,24 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, listViewsStatement, sliceParams) + resp, err := source.RunSQL(ctx, listViewsStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go b/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go index af35731d0c..2664c2e419 100644 --- a/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go +++ b/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go @@ -17,11 +17,13 @@ package postgreslongrunningtransactions import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,7 +31,7 @@ import ( const resourceType string = "postgres-long-running-transactions" const longRunningTransactions = ` - SELECT + SELECT pid, datname, usename, @@ -83,7 +85,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -103,7 +104,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -116,7 +116,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -130,20 +129,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, longRunningTransactions, sliceParams) + resp, err := source.RunSQL(ctx, longRunningTransactions, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go b/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go index bdb45cda4e..495c640140 100644 --- a/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go +++ b/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go @@ -17,11 +17,13 @@ package postgresreplicationstats import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -29,11 +31,11 @@ import ( const resourceType string = "postgres-replication-stats" const replicationStats = ` - SELECT - pid, - usename, + SELECT + pid, + usename, application_name, - backend_xmin, + backend_xmin, client_addr, state, sync_state, @@ -73,7 +75,6 @@ type Config struct { AuthRequired []string `yaml:"authRequired"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -90,7 +91,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup return Tool{ Config: cfg, allParams: allParameters, @@ -103,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) }, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -117,20 +116,24 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, replicationStats, sliceParams) + resp, err := source.RunSQL(ctx, replicationStats, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgressql/postgressql.go b/internal/tools/postgres/postgressql/postgressql.go index adfc6e830c..ece775a356 100644 --- a/internal/tools/postgres/postgressql/postgressql.go +++ b/internal/tools/postgres/postgressql/postgressql.go @@ -17,11 +17,13 @@ package postgressql import ( "context" "fmt" + "net/http" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -58,7 +60,6 @@ type Config struct { TemplateParameters parameters.Parameters `yaml:"templateParameters"` } -// validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigType() string { @@ -73,7 +74,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - // finish tool setup t := Tool{ Config: cfg, AllParams: allParameters, @@ -83,7 +83,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return t, nil } -// validate interface var _ tools.Tool = Tool{} type Tool struct { @@ -93,24 +92,28 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) if err != nil { - return nil, err + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) } paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract template params %w", err) + return nil, util.NewAgentError("unable to extract template params", err) } newParams, err := parameters.GetParams(t.Parameters, paramsMap) if err != nil { - return nil, fmt.Errorf("unable to extract standard params %w", err) + return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() - return source.RunSQL(ctx, newStatement, sliceParams) + resp, err := source.RunSQL(ctx, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil } func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {