From 934751b753c0b20cedcb128e54287573b2a367a6 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Mon, 9 Feb 2026 19:51:06 -0800 Subject: [PATCH] refactor: refactor subcommands and move tests to its own package --- cmd/imports/imports.go | 253 ++ cmd/options.go | 30 - cmd/root.go | 876 +------ cmd/root_test.go | 2204 +---------------- internal/cli/invoke/command.go | 58 +- .../cli/invoke/command_test.go | 30 +- internal/cli/options.go | 251 ++ {cmd => internal/cli}/options_test.go | 33 +- internal/cli/persistent_flags.go | 46 + internal/cli/skills/command.go | 83 +- .../cli/skills/command_test.go | 34 +- internal/cli/tools_file.go | 349 +++ internal/cli/tools_file_test.go | 2143 ++++++++++++++++ tests/server.go | 7 +- 14 files changed, 3265 insertions(+), 3132 deletions(-) create mode 100644 cmd/imports/imports.go delete mode 100644 cmd/options.go rename cmd/invoke_tool_test.go => internal/cli/invoke/command_test.go (80%) create mode 100644 internal/cli/options.go rename {cmd => internal/cli}/options_test.go (62%) create mode 100644 internal/cli/persistent_flags.go rename cmd/skill_generate_test.go => internal/cli/skills/command_test.go (87%) create mode 100644 internal/cli/tools_file.go create mode 100644 internal/cli/tools_file_test.go diff --git a/cmd/imports/imports.go b/cmd/imports/imports.go new file mode 100644 index 0000000000..4549fa7988 --- /dev/null +++ b/cmd/imports/imports.go @@ -0,0 +1,253 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package imports + +import ( + // Import prompt packages for side effect of registration + _ "github.com/googleapis/genai-toolbox/internal/prompts/custom" + + // Import tool packages for side effect of registration + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreatecluster" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreateinstance" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreateuser" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetcluster" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetinstance" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetuser" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistclusters" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistinstances" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistusers" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbwaitforoperation" + _ "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryanalyzecontribution" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryconversationalanalytics" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygetdatasetinfo" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygettableinfo" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylistdatasetids" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylisttableids" + _ 
"github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysearchcatalog" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigtable" + _ "github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql" + _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases" + _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdataset" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance" + _ 
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminlistlognames" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminlistresourcetypes" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminquerylogs" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistdatabases" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistinstances" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck" + _ "github.com/googleapis/genai-toolbox/internal/tools/couchbase" + _ 
"github.com/googleapis/genai-toolbox/internal/tools/dataform/dataformcompilelocal" + _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupentry" + _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchaspecttypes" + _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchentries" + _ "github.com/googleapis/genai-toolbox/internal/tools/dgraph" + _ "github.com/googleapis/genai-toolbox/internal/tools/elasticsearch/elasticsearchesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/firebird/firebirdexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/firebird/firebirdsql" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreadddocuments" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoredeletedocuments" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetdocuments" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetrules" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorelistcollections" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequery" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequerycollection" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreupdatedocument" + _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules" + _ "github.com/googleapis/genai-toolbox/internal/tools/http" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardfilter" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectfile" + _ 
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdevmode" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergenerateembedurl" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiondatabases" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnections" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectionschemas" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontablecolumns" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdashboards" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdimensions" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetexplores" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetfilters" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetlooks" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmeasures" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmodels" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetparameters" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojectfile" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojectfiles" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojects" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthanalyze" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthpulse" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthvacuum" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookermakedashboard" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookermakelook" + _ 
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquery" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquerysql" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerqueryurl" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrundashboard" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrunlook" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerupdateprojectfile" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookervalidateproject" + _ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbsql" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbaggregate" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeletemany" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeleteone" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfind" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfindone" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertmany" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertone" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdatemany" + _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdateone" + _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqllisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllistactivequeries" + _ 
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablefragmentation" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablesmissinguniqueindexes" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlsql" + _ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jcypher" + _ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher" + _ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema" + _ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbaseexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbasesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oracleexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresdatabaseoverview" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresgetcolumncardinality" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistactivequeries" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistavailableextensions" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistdatabasestats" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistindexes" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistinstalledextensions" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistlocks" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpgsettings" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpublicationtables" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistquerystats" + _ 
"github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistroles" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistschemas" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistsequences" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresliststoredprocedure" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablespaces" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablestats" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttriggers" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistviews" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslongrunningtransactions" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresreplicationstats" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" + _ "github.com/googleapis/genai-toolbox/internal/tools/redis" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" + _ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoreexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoresql" + _ "github.com/googleapis/genai-toolbox/internal/tools/snowflake/snowflakeexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/snowflake/snowflakesql" + _ 
"github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlistgraphs" + _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannersql" + _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbsql" + _ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinoexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinosql" + _ "github.com/googleapis/genai-toolbox/internal/tools/utility/wait" + _ "github.com/googleapis/genai-toolbox/internal/tools/valkey" + _ "github.com/googleapis/genai-toolbox/internal/tools/yugabytedbsql" + + // Import source packages for side effect of registration + _ "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" + _ "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" + _ "github.com/googleapis/genai-toolbox/internal/sources/bigquery" + _ "github.com/googleapis/genai-toolbox/internal/sources/bigtable" + _ "github.com/googleapis/genai-toolbox/internal/sources/cassandra" + _ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudloggingadmin" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" + _ 
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" + _ "github.com/googleapis/genai-toolbox/internal/sources/couchbase" + _ "github.com/googleapis/genai-toolbox/internal/sources/dataplex" + _ "github.com/googleapis/genai-toolbox/internal/sources/dgraph" + _ "github.com/googleapis/genai-toolbox/internal/sources/elasticsearch" + _ "github.com/googleapis/genai-toolbox/internal/sources/firebird" + _ "github.com/googleapis/genai-toolbox/internal/sources/firestore" + _ "github.com/googleapis/genai-toolbox/internal/sources/http" + _ "github.com/googleapis/genai-toolbox/internal/sources/looker" + _ "github.com/googleapis/genai-toolbox/internal/sources/mindsdb" + _ "github.com/googleapis/genai-toolbox/internal/sources/mongodb" + _ "github.com/googleapis/genai-toolbox/internal/sources/mssql" + _ "github.com/googleapis/genai-toolbox/internal/sources/mysql" + _ "github.com/googleapis/genai-toolbox/internal/sources/neo4j" + _ "github.com/googleapis/genai-toolbox/internal/sources/oceanbase" + _ "github.com/googleapis/genai-toolbox/internal/sources/oracle" + _ "github.com/googleapis/genai-toolbox/internal/sources/postgres" + _ "github.com/googleapis/genai-toolbox/internal/sources/redis" + _ "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" + _ "github.com/googleapis/genai-toolbox/internal/sources/singlestore" + _ "github.com/googleapis/genai-toolbox/internal/sources/snowflake" + _ "github.com/googleapis/genai-toolbox/internal/sources/spanner" + _ "github.com/googleapis/genai-toolbox/internal/sources/sqlite" + _ "github.com/googleapis/genai-toolbox/internal/sources/tidb" + _ "github.com/googleapis/genai-toolbox/internal/sources/trino" + _ "github.com/googleapis/genai-toolbox/internal/sources/valkey" + _ "github.com/googleapis/genai-toolbox/internal/sources/yugabytedb" +) diff --git a/cmd/options.go b/cmd/options.go deleted file mode 100644 index b87a7e6d55..0000000000 --- a/cmd/options.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2024 Google 
LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "io" -) - -// Option is a function that configures a Command. -type Option func(*Command) - -// WithStreams overrides the default writer. -func WithStreams(out, err io.Writer) Option { - return func(c *Command) { - c.outStream = out - c.errStream = err - } -} diff --git a/cmd/root.go b/cmd/root.go index 5e59997211..4aede9b2be 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -15,16 +15,13 @@ package cmd import ( - "bytes" "context" _ "embed" "fmt" - "io" "maps" "os" "os/signal" "path/filepath" - "regexp" "runtime" "slices" "strings" @@ -32,256 +29,20 @@ import ( "time" "github.com/fsnotify/fsnotify" - yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/cli" "github.com/googleapis/genai-toolbox/internal/cli/invoke" "github.com/googleapis/genai-toolbox/internal/cli/skills" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" - "github.com/googleapis/genai-toolbox/internal/log" - "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/telemetry" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" - - // Import prompt packages for 
side effect of registration - _ "github.com/googleapis/genai-toolbox/internal/prompts/custom" - - // Import tool packages for side effect of registration - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreatecluster" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreateinstance" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreateuser" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetcluster" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetinstance" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetuser" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistclusters" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistinstances" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistusers" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbwaitforoperation" - _ "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryanalyzecontribution" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryconversationalanalytics" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygetdatasetinfo" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygettableinfo" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylistdatasetids" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylisttableids" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysearchcatalog" - _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql" - _ 
"github.com/googleapis/genai-toolbox/internal/tools/bigtable" - _ "github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql" - _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases" - _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables" - _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdataset" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances" - _ 
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminlistlognames" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminlistresourcetypes" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminquerylogs" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistdatabases" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistinstances" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances" - _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck" - _ "github.com/googleapis/genai-toolbox/internal/tools/couchbase" - _ "github.com/googleapis/genai-toolbox/internal/tools/dataform/dataformcompilelocal" - _ 
"github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupentry" - _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchaspecttypes" - _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchentries" - _ "github.com/googleapis/genai-toolbox/internal/tools/dgraph" - _ "github.com/googleapis/genai-toolbox/internal/tools/elasticsearch/elasticsearchesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/firebird/firebirdexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/firebird/firebirdsql" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreadddocuments" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoredeletedocuments" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetdocuments" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetrules" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorelistcollections" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequery" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequerycollection" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreupdatedocument" - _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules" - _ "github.com/googleapis/genai-toolbox/internal/tools/http" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardfilter" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectfile" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdevmode" - _ 
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookergenerateembedurl" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiondatabases" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnections" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectionschemas" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontablecolumns" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdashboards" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdimensions" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetexplores" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetfilters" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetlooks" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmeasures" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmodels" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetparameters" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojectfile" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojectfiles" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojects" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthanalyze" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthpulse" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthvacuum" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookermakedashboard" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookermakelook" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquery" - _ 
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquerysql" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerqueryurl" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrundashboard" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrunlook" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerupdateprojectfile" - _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookervalidateproject" - _ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbsql" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbaggregate" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeletemany" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeleteone" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfind" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfindone" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertmany" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertone" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdatemany" - _ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdateone" - _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqllisttables" - _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllistactivequeries" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablefragmentation" - _ 
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablesmissinguniqueindexes" - _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlsql" - _ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jcypher" - _ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher" - _ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema" - _ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbaseexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbasesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oracleexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresdatabaseoverview" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresgetcolumncardinality" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistactivequeries" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistavailableextensions" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistdatabasestats" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistindexes" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistinstalledextensions" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistlocks" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpgsettings" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpublicationtables" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistquerystats" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistroles" - _ 
"github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistschemas" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistsequences" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresliststoredprocedure" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablespaces" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablestats" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttriggers" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistviews" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslongrunningtransactions" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresreplicationstats" - _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" - _ "github.com/googleapis/genai-toolbox/internal/tools/redis" - _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch" - _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch" - _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch" - _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch" - _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" - _ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoreexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoresql" - _ "github.com/googleapis/genai-toolbox/internal/tools/snowflake/snowflakeexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/snowflake/snowflakesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql" - _ 
"github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlistgraphs" - _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables" - _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannersql" - _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbsql" - _ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinoexecutesql" - _ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinosql" - _ "github.com/googleapis/genai-toolbox/internal/tools/utility/wait" - _ "github.com/googleapis/genai-toolbox/internal/tools/valkey" - _ "github.com/googleapis/genai-toolbox/internal/tools/yugabytedbsql" - "github.com/spf13/cobra" - _ "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" - _ "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - _ "github.com/googleapis/genai-toolbox/internal/sources/bigquery" - _ "github.com/googleapis/genai-toolbox/internal/sources/bigtable" - _ "github.com/googleapis/genai-toolbox/internal/sources/cassandra" - _ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudloggingadmin" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" - _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - _ "github.com/googleapis/genai-toolbox/internal/sources/couchbase" - _ 
"github.com/googleapis/genai-toolbox/internal/sources/dataplex" - _ "github.com/googleapis/genai-toolbox/internal/sources/dgraph" - _ "github.com/googleapis/genai-toolbox/internal/sources/elasticsearch" - _ "github.com/googleapis/genai-toolbox/internal/sources/firebird" - _ "github.com/googleapis/genai-toolbox/internal/sources/firestore" - _ "github.com/googleapis/genai-toolbox/internal/sources/http" - _ "github.com/googleapis/genai-toolbox/internal/sources/looker" - _ "github.com/googleapis/genai-toolbox/internal/sources/mindsdb" - _ "github.com/googleapis/genai-toolbox/internal/sources/mongodb" - _ "github.com/googleapis/genai-toolbox/internal/sources/mssql" - _ "github.com/googleapis/genai-toolbox/internal/sources/mysql" - _ "github.com/googleapis/genai-toolbox/internal/sources/neo4j" - _ "github.com/googleapis/genai-toolbox/internal/sources/oceanbase" - _ "github.com/googleapis/genai-toolbox/internal/sources/oracle" - _ "github.com/googleapis/genai-toolbox/internal/sources/postgres" - _ "github.com/googleapis/genai-toolbox/internal/sources/redis" - _ "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" - _ "github.com/googleapis/genai-toolbox/internal/sources/singlestore" - _ "github.com/googleapis/genai-toolbox/internal/sources/snowflake" - _ "github.com/googleapis/genai-toolbox/internal/sources/spanner" - _ "github.com/googleapis/genai-toolbox/internal/sources/sqlite" - _ "github.com/googleapis/genai-toolbox/internal/sources/tidb" - _ "github.com/googleapis/genai-toolbox/internal/sources/trino" - _ "github.com/googleapis/genai-toolbox/internal/sources/valkey" - _ "github.com/googleapis/genai-toolbox/internal/sources/yugabytedb" + // Import prompt packages for side effect of registration + _ "github.com/googleapis/genai-toolbox/cmd/imports" ) var ( @@ -313,422 +74,64 @@ func semanticVersion() string { // Execute adds all child commands to the root command and sets flags appropriately. // This is called by main.main(). 
It only needs to happen once to the rootCmd. func Execute() { - if err := NewCommand().Execute(); err != nil { + // Initialize cli deps + opts := cli.NewToolboxOptions() + + if err := NewCommand(opts).Execute(); err != nil { exit := 1 os.Exit(exit) } } -// Command represents an invocation of the CLI. -type Command struct { - *cobra.Command - - cfg server.ServerConfig - logger log.Logger - tools_file string - tools_files []string - tools_folder string - prebuiltConfigs []string - inStream io.Reader - outStream io.Writer - errStream io.Writer -} - // NewCommand returns a Command object representing an invocation of the CLI. -func NewCommand(opts ...Option) *Command { - in := os.Stdin - out := os.Stdout - err := os.Stderr - - baseCmd := &cobra.Command{ +func NewCommand(opts *cli.ToolboxOptions) *cobra.Command { + cmd := &cobra.Command{ Use: "toolbox", Version: versionString, SilenceErrors: true, } - cmd := &Command{ - Command: baseCmd, - inStream: in, - outStream: out, - errStream: err, - } - - for _, o := range opts { - o(cmd) - } // Do not print Usage on runtime error cmd.SilenceUsage = true // Set server version - cmd.cfg.Version = versionString + opts.Cfg.Version = versionString // set baseCmd in, out and err the same as cmd. 
- baseCmd.SetIn(cmd.inStream) - baseCmd.SetOut(cmd.outStream) - baseCmd.SetErr(cmd.errStream) + cmd.SetIn(opts.IOStreams.In) + cmd.SetOut(opts.IOStreams.Out) + cmd.SetErr(opts.IOStreams.ErrOut) + + // setup flags that are common across all commands + cli.PersistentFlags(cmd, opts) flags := cmd.Flags() - persistentFlags := cmd.PersistentFlags() - flags.StringVarP(&cmd.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.") - flags.IntVarP(&cmd.cfg.Port, "port", "p", 5000, "Port the server will listen on.") + flags.StringVarP(&opts.Cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.") + flags.IntVarP(&opts.Cfg.Port, "port", "p", 5000, "Port the server will listen on.") - flags.StringVar(&cmd.tools_file, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.") + flags.StringVar(&opts.ToolsFile, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.") // deprecate tools_file _ = flags.MarkDeprecated("tools_file", "please use --tools-file instead") - persistentFlags.StringVar(&cmd.tools_file, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.") - persistentFlags.StringSliceVar(&cmd.tools_files, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file, or --tools-folder.") - persistentFlags.StringVar(&cmd.tools_folder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file, or --tools-files.") - persistentFlags.Var(&cmd.cfg.LogLevel, "log-level", "Specify the minimum level logged. 
Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'.") - persistentFlags.Var(&cmd.cfg.LoggingFormat, "logging-format", "Specify logging format to use. Allowed: 'standard' or 'JSON'.") - persistentFlags.BoolVar(&cmd.cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.") - persistentFlags.StringVar(&cmd.cfg.TelemetryOTLP, "telemetry-otlp", "", "Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318')") - persistentFlags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.") - // Fetch prebuilt tools sources to customize the help description - prebuiltHelp := fmt.Sprintf( - "Use a prebuilt tool configuration by source type. Allowed: '%s'. Can be specified multiple times.", - strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"), - ) - persistentFlags.StringSliceVar(&cmd.prebuiltConfigs, "prebuilt", []string{}, prebuiltHelp) - flags.BoolVar(&cmd.cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.") - flags.BoolVar(&cmd.cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.") - flags.BoolVar(&cmd.cfg.UI, "ui", false, "Launches the Toolbox UI web server.") + flags.BoolVar(&opts.Cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.") + flags.BoolVar(&opts.Cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.") + flags.BoolVar(&opts.Cfg.UI, "ui", false, "Launches the Toolbox UI web server.") // TODO: Insecure by default. Might consider updating this for v1.0.0 - flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. 
Defaults to '*'.") - flags.StringSliceVar(&cmd.cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.") - persistentFlags.StringSliceVar(&cmd.cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.") + flags.StringSliceVar(&opts.Cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.") + flags.StringSliceVar(&opts.Cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.") // wrap RunE command so that we have access to original Command object - cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) } + cmd.RunE = func(*cobra.Command, []string) error { return run(cmd, opts) } // Register subcommands for tool invocation - baseCmd.AddCommand(invoke.NewCommand(cmd)) + cmd.AddCommand(invoke.NewCommand(opts)) // Register subcommands for skill generation - baseCmd.AddCommand(skills.NewCommand(cmd)) + cmd.AddCommand(skills.NewCommand(opts)) return cmd } -type ToolsFile struct { - Sources server.SourceConfigs `yaml:"sources"` - AuthServices server.AuthServiceConfigs `yaml:"authServices"` - EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"` - Tools server.ToolConfigs `yaml:"tools"` - Toolsets server.ToolsetConfigs `yaml:"toolsets"` - Prompts server.PromptConfigs `yaml:"prompts"` -} - -// parseEnv replaces environment variables ${ENV_NAME} with their values. -// also support ${ENV_NAME:default_value}. 
-func parseEnv(input string) (string, error) { - re := regexp.MustCompile(`\$\{(\w+)(:([^}]*))?\}`) - - var err error - output := re.ReplaceAllStringFunc(input, func(match string) string { - parts := re.FindStringSubmatch(match) - - // extract the variable name - variableName := parts[1] - if value, found := os.LookupEnv(variableName); found { - return value - } - if len(parts) >= 4 && parts[2] != "" { - return parts[3] - } - err = fmt.Errorf("environment variable not found: %q", variableName) - return "" - }) - return output, err -} - -func convertToolsFile(raw []byte) ([]byte, error) { - var input yaml.MapSlice - decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap()) - - // convert to tools file v2 - var buf bytes.Buffer - encoder := yaml.NewEncoder(&buf) - - v1keys := []string{"sources", "authSources", "authServices", "embeddingModels", "tools", "toolsets", "prompts"} - for { - if err := decoder.Decode(&input); err != nil { - if err == io.EOF { - break - } - return nil, err - } - for _, item := range input { - key, ok := item.Key.(string) - if !ok { - return nil, fmt.Errorf("unexpected non-string key in input: %v", item.Key) - } - // check if the key is config file v1's key - if slices.Contains(v1keys, key) { - // check if value conversion to yaml.MapSlice successfully - // fields such as "tools" in toolsets might pass the first check but - // fail to convert to MapSlice - if slice, ok := item.Value.(yaml.MapSlice); ok { - // Deprecated: convert authSources to authServices - if key == "authSources" { - key = "authServices" - } - transformed, err := transformDocs(key, slice) - if err != nil { - return nil, err - } - // encode per-doc - for _, doc := range transformed { - if err := encoder.Encode(doc); err != nil { - return nil, err - } - } - } else { - // invalid input will be ignored - // we don't want to throw error here since the config could - // be valid but with a different order such as: - // --- - // tools: - // - tool_a - // kind: 
toolsets - // --- - continue - } - } else { - // this doc is already v2, encode to buf - if err := encoder.Encode(input); err != nil { - return nil, err - } - break - } - } - } - return buf.Bytes(), nil -} - -// transformDocs transforms the configuration file from v1 format to v2 -// yaml.MapSlice will preserve the order in a map -func transformDocs(kind string, input yaml.MapSlice) ([]yaml.MapSlice, error) { - var transformed []yaml.MapSlice - for _, entry := range input { - entryName, ok := entry.Key.(string) - if !ok { - return nil, fmt.Errorf("unexpected non-string key for entry in '%s': %v", kind, entry.Key) - } - entryBody := ProcessValue(entry.Value, kind == "toolsets") - - currentTransformed := yaml.MapSlice{ - {Key: "kind", Value: kind}, - {Key: "name", Value: entryName}, - } - - // Merge the transformed body into our result - if bodySlice, ok := entryBody.(yaml.MapSlice); ok { - currentTransformed = append(currentTransformed, bodySlice...) - } else { - return nil, fmt.Errorf("unable to convert entryBody to MapSlice") - } - transformed = append(transformed, currentTransformed) - } - return transformed, nil -} - -// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type' -func ProcessValue(v any, isToolset bool) any { - switch val := v.(type) { - case yaml.MapSlice: - // creating a new MapSlice is safer for recursive transformation - newVal := make(yaml.MapSlice, len(val)) - for i, item := range val { - // Perform renaming - if item.Key == "kind" { - item.Key = "type" - } - // Recursive call for nested values (e.g., nested objects or lists) - item.Value = ProcessValue(item.Value, false) - newVal[i] = item - } - return newVal - case []any: - // Process lists: If it's a toolset top-level list, wrap it. 
- if isToolset { - return yaml.MapSlice{{Key: "tools", Value: val}} - } - // Otherwise, recurse into list items (to catch nested objects) - newVal := make([]any, len(val)) - for i := range val { - newVal[i] = ProcessValue(val[i], false) - } - return newVal - default: - return val - } -} - -// parseToolsFile parses the provided yaml into appropriate configs. -func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) { - var toolsFile ToolsFile - // Replace environment variables if found - output, err := parseEnv(string(raw)) - if err != nil { - return toolsFile, fmt.Errorf("error parsing environment variables: %s", err) - } - raw = []byte(output) - - raw, err = convertToolsFile(raw) - if err != nil { - return toolsFile, fmt.Errorf("error converting tools file: %s", err) - } - - // Parse contents - toolsFile.Sources, toolsFile.AuthServices, toolsFile.EmbeddingModels, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts, err = server.UnmarshalResourceConfig(ctx, raw) - if err != nil { - return toolsFile, err - } - return toolsFile, nil -} - -// mergeToolsFiles merges multiple ToolsFile structs into one. -// Detects and raises errors for resource conflicts in sources, authServices, tools, and toolsets. -// All resource names (sources, authServices, tools, toolsets) must be unique across all files. 
-func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) { - merged := ToolsFile{ - Sources: make(server.SourceConfigs), - AuthServices: make(server.AuthServiceConfigs), - EmbeddingModels: make(server.EmbeddingModelConfigs), - Tools: make(server.ToolConfigs), - Toolsets: make(server.ToolsetConfigs), - Prompts: make(server.PromptConfigs), - } - - var conflicts []string - - for fileIndex, file := range files { - // Check for conflicts and merge sources - for name, source := range file.Sources { - if _, exists := merged.Sources[name]; exists { - conflicts = append(conflicts, fmt.Sprintf("source '%s' (file #%d)", name, fileIndex+1)) - } else { - merged.Sources[name] = source - } - } - - // Check for conflicts and merge authServices - for name, authService := range file.AuthServices { - if _, exists := merged.AuthServices[name]; exists { - conflicts = append(conflicts, fmt.Sprintf("authService '%s' (file #%d)", name, fileIndex+1)) - } else { - merged.AuthServices[name] = authService - } - } - - // Check for conflicts and merge embeddingModels - for name, em := range file.EmbeddingModels { - if _, exists := merged.EmbeddingModels[name]; exists { - conflicts = append(conflicts, fmt.Sprintf("embedding model '%s' (file #%d)", name, fileIndex+1)) - } else { - merged.EmbeddingModels[name] = em - } - } - - // Check for conflicts and merge tools - for name, tool := range file.Tools { - if _, exists := merged.Tools[name]; exists { - conflicts = append(conflicts, fmt.Sprintf("tool '%s' (file #%d)", name, fileIndex+1)) - } else { - merged.Tools[name] = tool - } - } - - // Check for conflicts and merge toolsets - for name, toolset := range file.Toolsets { - if _, exists := merged.Toolsets[name]; exists { - conflicts = append(conflicts, fmt.Sprintf("toolset '%s' (file #%d)", name, fileIndex+1)) - } else { - merged.Toolsets[name] = toolset - } - } - - // Check for conflicts and merge prompts - for name, prompt := range file.Prompts { - if _, exists := merged.Prompts[name]; exists 
{ - conflicts = append(conflicts, fmt.Sprintf("prompt '%s' (file #%d)", name, fileIndex+1)) - } else { - merged.Prompts[name] = prompt - } - } - } - - // If conflicts were detected, return an error - if len(conflicts) > 0 { - return ToolsFile{}, fmt.Errorf("resource conflicts detected:\n - %s\n\nPlease ensure each source, authService, tool, toolset and prompt has a unique name across all files", strings.Join(conflicts, "\n - ")) - } - - return merged, nil -} - -// loadAndMergeToolsFiles loads multiple YAML files and merges them -func loadAndMergeToolsFiles(ctx context.Context, filePaths []string) (ToolsFile, error) { - var toolsFiles []ToolsFile - - for _, filePath := range filePaths { - buf, err := os.ReadFile(filePath) - if err != nil { - return ToolsFile{}, fmt.Errorf("unable to read tool file at %q: %w", filePath, err) - } - - toolsFile, err := parseToolsFile(ctx, buf) - if err != nil { - return ToolsFile{}, fmt.Errorf("unable to parse tool file at %q: %w", filePath, err) - } - - toolsFiles = append(toolsFiles, toolsFile) - } - - mergedFile, err := mergeToolsFiles(toolsFiles...) 
- if err != nil { - return ToolsFile{}, fmt.Errorf("unable to merge tools files: %w", err) - } - - return mergedFile, nil -} - -// loadAndMergeToolsFolder loads all YAML files from a directory and merges them -func loadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile, error) { - // Check if directory exists - info, err := os.Stat(folderPath) - if err != nil { - return ToolsFile{}, fmt.Errorf("unable to access tools folder at %q: %w", folderPath, err) - } - if !info.IsDir() { - return ToolsFile{}, fmt.Errorf("path %q is not a directory", folderPath) - } - - // Find all YAML files in the directory - pattern := filepath.Join(folderPath, "*.yaml") - yamlFiles, err := filepath.Glob(pattern) - if err != nil { - return ToolsFile{}, fmt.Errorf("error finding YAML files in %q: %w", folderPath, err) - } - - // Also find .yml files - ymlPattern := filepath.Join(folderPath, "*.yml") - ymlFiles, err := filepath.Glob(ymlPattern) - if err != nil { - return ToolsFile{}, fmt.Errorf("error finding YML files in %q: %w", folderPath, err) - } - - // Combine both file lists - allFiles := append(yamlFiles, ymlFiles...) 
- - if len(allFiles) == 0 { - return ToolsFile{}, fmt.Errorf("no YAML files found in directory %q", folderPath) - } - - // Use existing loadAndMergeToolsFiles function - return loadAndMergeToolsFiles(ctx, allFiles) -} - -func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Server) error { +func handleDynamicReload(ctx context.Context, toolsFile cli.ToolsFile, s *server.Server) error { logger, err := util.LoggerFromContext(ctx) if err != nil { panic(err) @@ -748,7 +151,7 @@ func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Ser // validateReloadEdits checks that the reloaded tools file configs can initialized without failing func validateReloadEdits( - ctx context.Context, toolsFile ToolsFile, + ctx context.Context, toolsFile cli.ToolsFile, ) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error, ) { logger, err := util.LoggerFromContext(ctx) @@ -874,18 +277,18 @@ func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles m case <-debounce.C: debounce.Stop() - var reloadedToolsFile ToolsFile + var reloadedToolsFile cli.ToolsFile if watchingFolder { logger.DebugContext(ctx, "Reloading tools folder.") - reloadedToolsFile, err = loadAndMergeToolsFolder(ctx, folderToWatch) + reloadedToolsFile, err = cli.LoadAndMergeToolsFolder(ctx, folderToWatch) if err != nil { logger.WarnContext(ctx, "error loading tools folder %s", err) continue } } else { logger.DebugContext(ctx, "Reloading tools file(s).") - reloadedToolsFile, err = loadAndMergeToolsFiles(ctx, slices.Collect(maps.Keys(watchedFiles))) + reloadedToolsFile, err = cli.LoadAndMergeToolsFiles(ctx, slices.Collect(maps.Keys(watchedFiles))) if err != nil { logger.WarnContext(ctx, "error loading tools files %s", err) continue @@ -929,184 +332,7 @@ func resolveWatcherInputs(toolsFile string, toolsFiles 
[]string, toolsFolder str return watchDirs, watchedFiles } -func (cmd *Command) Config() server.ServerConfig { - return cmd.cfg -} - -func (cmd *Command) Out() io.Writer { - return cmd.outStream -} - -func (cmd *Command) Logger() log.Logger { - return cmd.logger -} - -func (cmd *Command) LoadConfig(ctx context.Context) error { - logger, err := util.LoggerFromContext(ctx) - if err != nil { - return err - } - - var allToolsFiles []ToolsFile - - // Load Prebuilt Configuration - - if len(cmd.prebuiltConfigs) > 0 { - slices.Sort(cmd.prebuiltConfigs) - sourcesList := strings.Join(cmd.prebuiltConfigs, ", ") - logMsg := fmt.Sprintf("Using prebuilt tool configurations for: %s", sourcesList) - logger.InfoContext(ctx, logMsg) - - for _, configName := range cmd.prebuiltConfigs { - buf, err := prebuiltconfigs.Get(configName) - if err != nil { - logger.ErrorContext(ctx, err.Error()) - return err - } - - // Parse into ToolsFile struct - parsed, err := parseToolsFile(ctx, buf) - if err != nil { - errMsg := fmt.Errorf("unable to parse prebuilt tool configuration for '%s': %w", configName, err) - logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - allToolsFiles = append(allToolsFiles, parsed) - } - } - - // Determine if Custom Files should be loaded - // Check for explicit custom flags - isCustomConfigured := cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != "" - - // Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags) - useDefaultToolsFile := len(cmd.prebuiltConfigs) == 0 && !isCustomConfigured - - if useDefaultToolsFile { - cmd.tools_file = "tools.yaml" - isCustomConfigured = true - } - - // Load Custom Configurations - if isCustomConfigured { - // Enforce exclusivity among custom flags (tools-file vs tools-files vs tools-folder) - if (cmd.tools_file != "" && len(cmd.tools_files) > 0) || - (cmd.tools_file != "" && cmd.tools_folder != "") || - (len(cmd.tools_files) > 0 && cmd.tools_folder != "") { - errMsg := 
fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously") - logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - - var customTools ToolsFile - var err error - - if len(cmd.tools_files) > 0 { - // Use tools-files - logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files))) - customTools, err = loadAndMergeToolsFiles(ctx, cmd.tools_files) - } else if cmd.tools_folder != "" { - // Use tools-folder - logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder)) - customTools, err = loadAndMergeToolsFolder(ctx, cmd.tools_folder) - } else { - // Use single file (tools-file or default `tools.yaml`) - buf, readFileErr := os.ReadFile(cmd.tools_file) - if readFileErr != nil { - errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, readFileErr) - logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - customTools, err = parseToolsFile(ctx, buf) - if err != nil { - err = fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err) - } - } - - if err != nil { - logger.ErrorContext(ctx, err.Error()) - return err - } - allToolsFiles = append(allToolsFiles, customTools) - } - - // Modify version string based on loaded configurations - if len(cmd.prebuiltConfigs) > 0 { - tag := "prebuilt" - if isCustomConfigured { - tag = "custom" - } - // cmd.prebuiltConfigs is already sorted above - for _, configName := range cmd.prebuiltConfigs { - cmd.cfg.Version += fmt.Sprintf("+%s.%s", tag, configName) - } - } - - // Merge Everything - // This will error if custom tools collide with prebuilt tools - finalToolsFile, err := mergeToolsFiles(allToolsFiles...) 
- if err != nil { - logger.ErrorContext(ctx, err.Error()) - return err - } - - cmd.cfg.SourceConfigs = finalToolsFile.Sources - cmd.cfg.AuthServiceConfigs = finalToolsFile.AuthServices - cmd.cfg.EmbeddingModelConfigs = finalToolsFile.EmbeddingModels - cmd.cfg.ToolConfigs = finalToolsFile.Tools - cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets - cmd.cfg.PromptConfigs = finalToolsFile.Prompts - - return nil -} - -func (cmd *Command) Setup(ctx context.Context) (context.Context, func(context.Context) error, error) { - // If stdio, set logger's out stream (usually DEBUG and INFO logs) to errStream - loggerOut := cmd.outStream - if cmd.cfg.Stdio { - loggerOut = cmd.errStream - } - - // Handle logger separately from config - logger, err := log.NewLogger(cmd.cfg.LoggingFormat.String(), cmd.cfg.LogLevel.String(), loggerOut, cmd.errStream) - if err != nil { - return ctx, nil, fmt.Errorf("unable to initialize logger: %w", err) - } - cmd.logger = logger - - ctx = util.WithLogger(ctx, cmd.logger) - - // Set up OpenTelemetry - otelShutdown, err := telemetry.SetupOTel(ctx, cmd.cfg.Version, cmd.cfg.TelemetryOTLP, cmd.cfg.TelemetryGCP, cmd.cfg.TelemetryServiceName) - if err != nil { - errMsg := fmt.Errorf("error setting up OpenTelemetry: %w", err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return ctx, nil, errMsg - } - - shutdownFunc := func(ctx context.Context) error { - err := otelShutdown(ctx) - if err != nil { - errMsg := fmt.Errorf("error shutting down OpenTelemetry: %w", err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return err - } - return nil - } - - instrumentation, err := telemetry.CreateTelemetryInstrumentation(cmd.cfg.Version) - if err != nil { - errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return ctx, shutdownFunc, errMsg - } - - ctx = util.WithInstrumentation(ctx, instrumentation) - - return ctx, shutdownFunc, nil -} - -func run(cmd *Command) error { +func run(cmd 
*cobra.Command, opts *cli.ToolboxOptions) error { ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() @@ -1123,14 +349,14 @@ func run(cmd *Command) error { } switch s { case syscall.SIGINT: - cmd.logger.DebugContext(sCtx, "Received SIGINT signal to shutdown.") + opts.Logger.DebugContext(sCtx, "Received SIGINT signal to shutdown.") case syscall.SIGTERM: - cmd.logger.DebugContext(sCtx, "Sending SIGTERM signal to shutdown.") + opts.Logger.DebugContext(sCtx, "Sending SIGTERM signal to shutdown.") } cancel() }(ctx) - ctx, shutdown, err := cmd.Setup(ctx) + ctx, shutdown, err := opts.Setup(ctx) if err != nil { return err } @@ -1138,24 +364,25 @@ func run(cmd *Command) error { _ = shutdown(ctx) }() - if err := cmd.LoadConfig(ctx); err != nil { + isCustomConfigured, err := opts.LoadConfig(ctx) + if err != nil { return err } // start server - s, err := server.NewServer(ctx, cmd.cfg) + s, err := server.NewServer(ctx, opts.Cfg) if err != nil { errMsg := fmt.Errorf("toolbox failed to initialize: %w", err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } // run server in background srvErr := make(chan error) - if cmd.cfg.Stdio { + if opts.Cfg.Stdio { go func() { defer close(srvErr) - err = s.ServeStdio(ctx, cmd.inStream, cmd.outStream) + err = s.ServeStdio(ctx, opts.IOStreams.In, opts.IOStreams.Out) if err != nil { srvErr <- err } @@ -1164,12 +391,12 @@ func run(cmd *Command) error { err = s.Listen(ctx) if err != nil { errMsg := fmt.Errorf("toolbox failed to start listener: %w", err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } - cmd.logger.InfoContext(ctx, "Server ready to serve!") - if cmd.cfg.UI { - cmd.logger.InfoContext(ctx, fmt.Sprintf("Toolbox UI is up and running at: http://%s:%d/ui", cmd.cfg.Address, cmd.cfg.Port)) + opts.Logger.InfoContext(ctx, "Server ready to serve!") + if opts.Cfg.UI { + opts.Logger.InfoContext(ctx, 
fmt.Sprintf("Toolbox UI is up and running at: http://%s:%d/ui", opts.Cfg.Address, opts.Cfg.Port)) } go func() { @@ -1181,11 +408,8 @@ func run(cmd *Command) error { }() } - // Determine if Custom Files are configured (re-check as loadAndMergeConfig might have updated defaults) - isCustomConfigured := cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != "" - - if isCustomConfigured && !cmd.cfg.DisableReload { - watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder) + if isCustomConfigured && !opts.Cfg.DisableReload { + watchDirs, watchedFiles := resolveWatcherInputs(opts.ToolsFile, opts.ToolsFiles, opts.ToolsFolder) // start watching the file(s) or folder for changes to trigger dynamic reloading go watchChanges(ctx, watchDirs, watchedFiles, s) } @@ -1195,13 +419,13 @@ func run(cmd *Command) error { case err := <-srvErr: if err != nil { errMsg := fmt.Errorf("toolbox crashed with the following error: %w", err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } case <-ctx.Done(): shutdownContext, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cmd.logger.WarnContext(shutdownContext, "Shutting down gracefully...") + opts.Logger.WarnContext(shutdownContext, "Shutting down gracefully...") err := s.Shutdown(shutdownContext) if err == context.DeadlineExceeded { return fmt.Errorf("graceful shutdown timed out... 
forcing exit") diff --git a/cmd/root_test.go b/cmd/root_test.go index f26bd1706a..cd04d7de17 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -31,22 +31,12 @@ import ( "github.com/google/go-cmp/cmp" - "github.com/googleapis/genai-toolbox/internal/auth/google" - "github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini" + "github.com/googleapis/genai-toolbox/internal/cli" "github.com/googleapis/genai-toolbox/internal/log" - "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" - "github.com/googleapis/genai-toolbox/internal/prompts" - "github.com/googleapis/genai-toolbox/internal/prompts/custom" "github.com/googleapis/genai-toolbox/internal/server" - cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http" "github.com/googleapis/genai-toolbox/internal/telemetry" "github.com/googleapis/genai-toolbox/internal/testutils" - "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/http" - "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" "github.com/googleapis/genai-toolbox/internal/util" - "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/spf13/cobra" ) @@ -76,15 +66,16 @@ func withDefaults(c server.ServerConfig) server.ServerConfig { return c } -func invokeCommand(args []string) (*Command, string, error) { - c := NewCommand() +func invokeCommand(args []string) (*cobra.Command, *cli.ToolboxOptions, string, error) { + buf := new(bytes.Buffer) + opts := cli.NewToolboxOptions(cli.WithIOStreams(buf, buf)) + c := NewCommand(opts) // Keep the test output quiet c.SilenceUsage = true c.SilenceErrors = true // Capture output - buf := new(bytes.Buffer) c.SetOut(buf) c.SetErr(buf) c.SetArgs(args) @@ -96,22 +87,23 @@ func invokeCommand(args []string) (*Command, string, error) { err := c.Execute() - return c, buf.String(), err + return c, opts, buf.String(), err } // 
invokeCommandWithContext executes the command with a context and returns the captured output. -func invokeCommandWithContext(ctx context.Context, args []string) (*Command, string, error) { - // Capture output using a buffer +func invokeCommandWithContext(ctx context.Context, args []string) (*cobra.Command, *cli.ToolboxOptions, string, error) { buf := new(bytes.Buffer) - c := NewCommand(WithStreams(buf, buf)) + opts := cli.NewToolboxOptions(cli.WithIOStreams(buf, buf)) + c := NewCommand(opts) + // Capture output using a buffer c.SetArgs(args) c.SilenceUsage = true c.SilenceErrors = true c.SetContext(ctx) err := c.Execute() - return c, buf.String(), err + return c, opts, buf.String(), err } func TestVersion(t *testing.T) { @@ -121,7 +113,7 @@ func TestVersion(t *testing.T) { } want := strings.TrimSpace(string(data)) - _, got, err := invokeCommand([]string{"--version"}) + _, _, got, err := invokeCommand([]string{"--version"}) if err != nil { t.Fatalf("error invoking command: %s", err) } @@ -243,79 +235,13 @@ func TestServerConfigFlags(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - c, _, err := invokeCommand(tc.args) + _, opts, _, err := invokeCommand(tc.args) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - if !cmp.Equal(c.cfg, tc.want) { - t.Fatalf("got %v, want %v", c.cfg, tc.want) - } - }) - } -} - -func TestParseEnv(t *testing.T) { - tcs := []struct { - desc string - env map[string]string - in string - want string - err bool - errString string - }{ - { - desc: "without default without env", - in: "${FOO}", - want: "", - err: true, - errString: `environment variable not found: "FOO"`, - }, - { - desc: "without default with env", - env: map[string]string{ - "FOO": "bar", - }, - in: "${FOO}", - want: "bar", - }, - { - desc: "with empty default", - in: "${FOO:}", - want: "", - }, - { - desc: "with default", - in: "${FOO:bar}", - want: "bar", - }, - { - desc: "with default with env", - env: map[string]string{ 
- "FOO": "hello", - }, - in: "${FOO:bar}", - want: "hello", - }, - } - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - if tc.env != nil { - for k, v := range tc.env { - t.Setenv(k, v) - } - } - got, err := parseEnv(tc.in) - if tc.err { - if err == nil { - t.Fatalf("expected error not found") - } - if tc.errString != err.Error() { - t.Fatalf("incorrect error string: got %s, want %s", err, tc.errString) - } - } - if tc.want != got { - t.Fatalf("unexpected want: got %s, want %s", got, tc.want) + if !cmp.Equal(opts.Cfg, tc.want) { + t.Fatalf("got %v, want %v", opts.Cfg, tc.want) } }) } @@ -350,12 +276,12 @@ func TestToolFileFlag(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - c, _, err := invokeCommand(tc.args) + _, opts, _, err := invokeCommand(tc.args) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - if c.tools_file != tc.want { - t.Fatalf("got %v, want %v", c.cfg, tc.want) + if opts.ToolsFile != tc.want { + t.Fatalf("got %v, want %v", opts.Cfg, tc.want) } }) } @@ -385,12 +311,12 @@ func TestToolsFilesFlag(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - c, _, err := invokeCommand(tc.args) + _, opts, _, err := invokeCommand(tc.args) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - if diff := cmp.Diff(c.tools_files, tc.want); diff != "" { - t.Fatalf("got %v, want %v", c.tools_files, tc.want) + if diff := cmp.Diff(opts.ToolsFiles, tc.want); diff != "" { + t.Fatalf("got %v, want %v", opts.ToolsFiles, tc.want) } }) } @@ -415,12 +341,12 @@ func TestToolsFolderFlag(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - c, _, err := invokeCommand(tc.args) + _, opts, _, err := invokeCommand(tc.args) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - if c.tools_folder != tc.want { - t.Fatalf("got %v, want %v", c.tools_folder, tc.want) + if opts.ToolsFolder != tc.want { + t.Fatalf("got %v, want 
%v", opts.ToolsFolder, tc.want) } }) } @@ -455,12 +381,12 @@ func TestPrebuiltFlag(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - c, _, err := invokeCommand(tc.args) + _, opts, _, err := invokeCommand(tc.args) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - if diff := cmp.Diff(c.prebuiltConfigs, tc.want); diff != "" { - t.Fatalf("got %v, want %v, diff %s", c.prebuiltConfigs, tc.want, diff) + if diff := cmp.Diff(opts.PrebuiltConfigs, tc.want); diff != "" { + t.Fatalf("got %v, want %v, diff %s", opts.PrebuiltConfigs, tc.want, diff) } }) } @@ -482,7 +408,7 @@ func TestFailServerConfigFlags(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - _, _, err := invokeCommand(tc.args) + _, _, _, err := invokeCommand(tc.args) if err == nil { t.Fatalf("expected an error, but got nil") } @@ -491,11 +417,11 @@ func TestFailServerConfigFlags(t *testing.T) { } func TestDefaultLoggingFormat(t *testing.T) { - c, _, err := invokeCommand([]string{}) + _, opts, _, err := invokeCommand([]string{}) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - got := c.cfg.LoggingFormat.String() + got := opts.Cfg.LoggingFormat.String() want := "standard" if got != want { t.Fatalf("unexpected default logging format flag: got %v, want %v", got, want) @@ -503,1377 +429,17 @@ func TestDefaultLoggingFormat(t *testing.T) { } func TestDefaultLogLevel(t *testing.T) { - c, _, err := invokeCommand([]string{}) + _, opts, _, err := invokeCommand([]string{}) if err != nil { t.Fatalf("unexpected error invoking command: %s", err) } - got := c.cfg.LogLevel.String() + got := opts.Cfg.LogLevel.String() want := "info" if got != want { t.Fatalf("unexpected default log level flag: got %v, want %v", got, want) } } -func TestConvertToolsFile(t *testing.T) { - tcs := []struct { - desc string - in string - want string - isErr bool - errStr string - }{ - { - desc: "basic convert", - in: ` - sources: - 
my-pg-instance: - kind: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass - authServices: - my-google-auth: - kind: google - clientId: testing-id - tools: - example_tool: - kind: postgres-sql - source: my-pg-instance - description: some description - statement: SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description - toolsets: - example_toolset: - - example_tool - prompts: - code_review: - description: ask llm to analyze code quality - messages: - - content: "please review the following code for quality: {{.code}}" - arguments: - - name: code - description: the code to review - embeddingModels: - gemini-model: - kind: gemini - model: gemini-embedding-001 - apiKey: some-key - dimension: 768`, - want: `kind: sources -name: my-pg-instance -type: cloud-sql-postgres -project: my-project -region: my-region -instance: my-instance -database: my_db -user: my_user -password: my_pass ---- -kind: authServices -name: my-google-auth -type: google -clientId: testing-id ---- -kind: tools -name: example_tool -type: postgres-sql -source: my-pg-instance -description: some description -statement: SELECT * FROM SQL_STATEMENT; -parameters: -- name: country - type: string - description: some description ---- -kind: toolsets -name: example_toolset -tools: -- example_tool ---- -kind: prompts -name: code_review -description: ask llm to analyze code quality -messages: -- content: "please review the following code for quality: {{.code}}" -arguments: -- name: code - description: the code to review ---- -kind: embeddingModels -name: gemini-model -type: gemini -model: gemini-embedding-001 -apiKey: some-key -dimension: 768 -`, - }, - { - desc: "preserve resource order", - in: ` - tools: - example_tool: - kind: postgres-sql - source: my-pg-instance - description: some description - statement: SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - 
type: string - description: some description - sources: - my-pg-instance: - kind: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass - authServices: - my-google-auth: - kind: google - clientId: testing-id - toolsets: - example_toolset: - - example_tool - authSources: - my-google-auth2: - kind: google - clientId: testing-id`, - want: `kind: tools -name: example_tool -type: postgres-sql -source: my-pg-instance -description: some description -statement: SELECT * FROM SQL_STATEMENT; -parameters: -- name: country - type: string - description: some description ---- -kind: sources -name: my-pg-instance -type: cloud-sql-postgres -project: my-project -region: my-region -instance: my-instance -database: my_db -user: my_user -password: my_pass ---- -kind: authServices -name: my-google-auth -type: google -clientId: testing-id ---- -kind: toolsets -name: example_toolset -tools: -- example_tool ---- -kind: authServices -name: my-google-auth2 -type: google -clientId: testing-id -`, - }, - { - desc: "convert combination of v1 and v2", - in: ` - sources: - my-pg-instance: - kind: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass - authServices: - my-google-auth: - kind: google - clientId: testing-id - tools: - example_tool: - kind: postgres-sql - source: my-pg-instance - description: some description - statement: SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description - toolsets: - example_toolset: - - example_tool - prompts: - code_review: - description: ask llm to analyze code quality - messages: - - content: "please review the following code for quality: {{.code}}" - arguments: - - name: code - description: the code to review - embeddingModels: - gemini-model: - kind: gemini - model: gemini-embedding-001 - apiKey: some-key - dimension: 768 ---- - kind: 
sources - name: my-pg-instance2 - type: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance ---- - kind: authServices - name: my-google-auth2 - type: google - clientId: testing-id ---- - kind: tools - name: example_tool2 - type: postgres-sql - source: my-pg-instance - description: some description - statement: SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description ---- - kind: toolsets - name: example_toolset2 - tools: - - example_tool ---- - tools: - - example_tool - kind: toolsets - name: example_toolset3 ---- - kind: prompts - name: code_review2 - description: ask llm to analyze code quality - messages: - - content: "please review the following code for quality: {{.code}}" - arguments: - - name: code - description: the code to review ---- - kind: embeddingModels - name: gemini-model2 - type: gemini`, - want: `kind: sources -name: my-pg-instance -type: cloud-sql-postgres -project: my-project -region: my-region -instance: my-instance -database: my_db -user: my_user -password: my_pass ---- -kind: authServices -name: my-google-auth -type: google -clientId: testing-id ---- -kind: tools -name: example_tool -type: postgres-sql -source: my-pg-instance -description: some description -statement: SELECT * FROM SQL_STATEMENT; -parameters: -- name: country - type: string - description: some description ---- -kind: toolsets -name: example_toolset -tools: -- example_tool ---- -kind: prompts -name: code_review -description: ask llm to analyze code quality -messages: -- content: "please review the following code for quality: {{.code}}" -arguments: -- name: code - description: the code to review ---- -kind: embeddingModels -name: gemini-model -type: gemini -model: gemini-embedding-001 -apiKey: some-key -dimension: 768 ---- -kind: sources -name: my-pg-instance2 -type: cloud-sql-postgres -project: my-project -region: my-region -instance: my-instance ---- -kind: authServices -name: my-google-auth2 
-type: google -clientId: testing-id ---- -kind: tools -name: example_tool2 -type: postgres-sql -source: my-pg-instance -description: some description -statement: SELECT * FROM SQL_STATEMENT; -parameters: -- name: country - type: string - description: some description ---- -kind: toolsets -name: example_toolset2 -tools: -- example_tool ---- -tools: -- example_tool -kind: toolsets -name: example_toolset3 ---- -kind: prompts -name: code_review2 -description: ask llm to analyze code quality -messages: -- content: "please review the following code for quality: {{.code}}" -arguments: -- name: code - description: the code to review ---- -kind: embeddingModels -name: gemini-model2 -type: gemini -`, - }, - { - desc: "no convertion needed", - in: `kind: sources -name: my-pg-instance -type: cloud-sql-postgres -project: my-project -region: my-region -instance: my-instance -database: my_db -user: my_user -password: my_pass ---- -kind: tools -name: example_tool -type: postgres-sql -source: my-pg-instance -description: some description -statement: SELECT * FROM SQL_STATEMENT; -parameters: -- name: country - type: string - description: some description ---- -kind: toolsets -name: example_toolset -tools: -- example_tool`, - want: `kind: sources -name: my-pg-instance -type: cloud-sql-postgres -project: my-project -region: my-region -instance: my-instance -database: my_db -user: my_user -password: my_pass ---- -kind: tools -name: example_tool -type: postgres-sql -source: my-pg-instance -description: some description -statement: SELECT * FROM SQL_STATEMENT; -parameters: -- name: country - type: string - description: some description ---- -kind: toolsets -name: example_toolset -tools: -- example_tool -`, - }, - { - desc: "invalid source", - in: `sources: invalid`, - want: "", - }, - { - desc: "invalid toolset", - in: `toolsets: invalid`, - want: "", - }, - } - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - output, err := convertToolsFile([]byte(tc.in)) - if err != nil 
{ - t.Fatalf("unexpected error: %s", err) - } - - if diff := cmp.Diff(string(output), tc.want); diff != "" { - t.Fatalf("incorrect toolsets parse: diff %v", diff) - } - }) - } -} - -func TestParseToolFile(t *testing.T) { - ctx, err := testutils.ContextWithNewLogger() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - tcs := []struct { - description string - in string - wantToolsFile ToolsFile - }{ - { - description: "basic example tools file v1", - in: ` - sources: - my-pg-instance: - kind: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass - tools: - example_tool: - kind: postgres-sql - source: my-pg-instance - description: some description - statement: | - SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description - toolsets: - example_toolset: - - example_tool - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-pg-instance": cloudsqlpgsrc.Config{ - Name: "my-pg-instance", - Type: cloudsqlpgsrc.SourceType, - Project: "my-project", - Region: "my-region", - Instance: "my-instance", - IPType: "public", - Database: "my_db", - User: "my_user", - Password: "my_pass", - }, - }, - Tools: server.ToolConfigs{ - "example_tool": postgressql.Config{ - Name: "example_tool", - Type: "postgres-sql", - Source: "my-pg-instance", - Description: "some description", - Statement: "SELECT * FROM SQL_STATEMENT;\n", - Parameters: []parameters.Parameter{ - parameters.NewStringParameter("country", "some description"), - }, - AuthRequired: []string{}, - }, - }, - Toolsets: server.ToolsetConfigs{ - "example_toolset": tools.ToolsetConfig{ - Name: "example_toolset", - ToolNames: []string{"example_tool"}, - }, - }, - AuthServices: nil, - Prompts: nil, - }, - }, - { - description: "basic example tools file v2", - in: ` - kind: sources - name: my-pg-instance - type: cloud-sql-postgres - project: my-project - region: my-region - 
instance: my-instance - database: my_db - user: my_user - password: my_pass ---- - kind: authServices - name: my-google-auth - type: google - clientId: testing-id ---- - kind: embeddingModels - name: gemini-model - type: gemini - model: gemini-embedding-001 - apiKey: some-key - dimension: 768 ---- - kind: tools - name: example_tool - type: postgres-sql - source: my-pg-instance - description: some description - statement: | - SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description ---- - kind: toolsets - name: example_toolset - tools: - - example_tool ---- - kind: prompts - name: code_review - description: ask llm to analyze code quality - messages: - - content: "please review the following code for quality: {{.code}}" - arguments: - - name: code - description: the code to review - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-pg-instance": cloudsqlpgsrc.Config{ - Name: "my-pg-instance", - Type: cloudsqlpgsrc.SourceType, - Project: "my-project", - Region: "my-region", - Instance: "my-instance", - IPType: "public", - Database: "my_db", - User: "my_user", - Password: "my_pass", - }, - }, - AuthServices: server.AuthServiceConfigs{ - "my-google-auth": google.Config{ - Name: "my-google-auth", - Type: google.AuthServiceType, - ClientID: "testing-id", - }, - }, - EmbeddingModels: server.EmbeddingModelConfigs{ - "gemini-model": gemini.Config{ - Name: "gemini-model", - Type: gemini.EmbeddingModelType, - Model: "gemini-embedding-001", - ApiKey: "some-key", - Dimension: 768, - }, - }, - Tools: server.ToolConfigs{ - "example_tool": postgressql.Config{ - Name: "example_tool", - Type: "postgres-sql", - Source: "my-pg-instance", - Description: "some description", - Statement: "SELECT * FROM SQL_STATEMENT;\n", - Parameters: []parameters.Parameter{ - parameters.NewStringParameter("country", "some description"), - }, - AuthRequired: []string{}, - }, - }, - Toolsets: server.ToolsetConfigs{ - "example_toolset": 
tools.ToolsetConfig{ - Name: "example_toolset", - ToolNames: []string{"example_tool"}, - }, - }, - Prompts: server.PromptConfigs{ - "code_review": &custom.Config{ - Name: "code_review", - Description: "ask llm to analyze code quality", - Arguments: prompts.Arguments{ - {Parameter: parameters.NewStringParameter("code", "the code to review")}, - }, - Messages: []prompts.Message{ - {Role: "user", Content: "please review the following code for quality: {{.code}}"}, - }, - }, - }, - }, - }, - { - description: "only prompts", - in: ` - kind: prompts - name: my-prompt - description: A prompt template for data analysis. - arguments: - - name: country - description: The country to analyze. - messages: - - content: Analyze the data for {{.country}}. - `, - wantToolsFile: ToolsFile{ - Sources: nil, - AuthServices: nil, - Tools: nil, - Toolsets: nil, - Prompts: server.PromptConfigs{ - "my-prompt": &custom.Config{ - Name: "my-prompt", - Description: "A prompt template for data analysis.", - Arguments: prompts.Arguments{ - {Parameter: parameters.NewStringParameter("country", "The country to analyze.")}, - }, - Messages: []prompts.Message{ - {Role: "user", Content: "Analyze the data for {{.country}}."}, - }, - }, - }, - }, - }, - } - for _, tc := range tcs { - t.Run(tc.description, func(t *testing.T) { - toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in)) - if err != nil { - t.Fatalf("failed to parse input: %v", err) - } - if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" { - t.Fatalf("incorrect sources parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" { - t.Fatalf("incorrect authServices parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" { - t.Fatalf("incorrect tools parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" { - t.Fatalf("incorrect toolsets parse: diff %v", 
diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" { - t.Fatalf("incorrect prompts parse: diff %v", diff) - } - }) - } - -} - -func TestParseToolFileWithAuth(t *testing.T) { - ctx, err := testutils.ContextWithNewLogger() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - tcs := []struct { - description string - in string - wantToolsFile ToolsFile - }{ - { - description: "basic example", - in: ` - kind: sources - name: my-pg-instance - type: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass ---- - kind: authServices - name: my-google-service - type: google - clientId: my-client-id ---- - kind: authServices - name: other-google-service - type: google - clientId: other-client-id ---- - kind: tools - name: example_tool - type: postgres-sql - source: my-pg-instance - description: some description - statement: | - SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description - - name: id - type: integer - description: user id - authServices: - - name: my-google-service - field: user_id - - name: email - type: string - description: user email - authServices: - - name: my-google-service - field: email - - name: other-google-service - field: other_email ---- - kind: toolsets - name: example_toolset - tools: - - example_tool - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-pg-instance": cloudsqlpgsrc.Config{ - Name: "my-pg-instance", - Type: cloudsqlpgsrc.SourceType, - Project: "my-project", - Region: "my-region", - Instance: "my-instance", - IPType: "public", - Database: "my_db", - User: "my_user", - Password: "my_pass", - }, - }, - AuthServices: server.AuthServiceConfigs{ - "my-google-service": google.Config{ - Name: "my-google-service", - Type: google.AuthServiceType, - ClientID: "my-client-id", - }, - "other-google-service": google.Config{ - Name: 
"other-google-service", - Type: google.AuthServiceType, - ClientID: "other-client-id", - }, - }, - Tools: server.ToolConfigs{ - "example_tool": postgressql.Config{ - Name: "example_tool", - Type: "postgres-sql", - Source: "my-pg-instance", - Description: "some description", - Statement: "SELECT * FROM SQL_STATEMENT;\n", - AuthRequired: []string{}, - Parameters: []parameters.Parameter{ - parameters.NewStringParameter("country", "some description"), - parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}), - parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}), - }, - }, - }, - Toolsets: server.ToolsetConfigs{ - "example_toolset": tools.ToolsetConfig{ - Name: "example_toolset", - ToolNames: []string{"example_tool"}, - }, - }, - Prompts: nil, - }, - }, - { - description: "basic example with authSources", - in: ` - sources: - my-pg-instance: - kind: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass - authSources: - my-google-service: - kind: google - clientId: my-client-id - other-google-service: - kind: google - clientId: other-client-id - - tools: - example_tool: - kind: postgres-sql - source: my-pg-instance - description: some description - statement: | - SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description - - name: id - type: integer - description: user id - authSources: - - name: my-google-service - field: user_id - - name: email - type: string - description: user email - authSources: - - name: my-google-service - field: email - - name: other-google-service - field: other_email - - toolsets: - example_toolset: - - example_tool - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-pg-instance": 
cloudsqlpgsrc.Config{ - Name: "my-pg-instance", - Type: cloudsqlpgsrc.SourceType, - Project: "my-project", - Region: "my-region", - Instance: "my-instance", - IPType: "public", - Database: "my_db", - User: "my_user", - Password: "my_pass", - }, - }, - AuthServices: server.AuthServiceConfigs{ - "my-google-service": google.Config{ - Name: "my-google-service", - Type: google.AuthServiceType, - ClientID: "my-client-id", - }, - "other-google-service": google.Config{ - Name: "other-google-service", - Type: google.AuthServiceType, - ClientID: "other-client-id", - }, - }, - Tools: server.ToolConfigs{ - "example_tool": postgressql.Config{ - Name: "example_tool", - Type: "postgres-sql", - Source: "my-pg-instance", - Description: "some description", - Statement: "SELECT * FROM SQL_STATEMENT;\n", - AuthRequired: []string{}, - Parameters: []parameters.Parameter{ - parameters.NewStringParameter("country", "some description"), - parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}), - parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}), - }, - }, - }, - Toolsets: server.ToolsetConfigs{ - "example_toolset": tools.ToolsetConfig{ - Name: "example_toolset", - ToolNames: []string{"example_tool"}, - }, - }, - Prompts: nil, - }, - }, - { - description: "basic example with authRequired", - in: ` - kind: sources - name: my-pg-instance - type: cloud-sql-postgres - project: my-project - region: my-region - instance: my-instance - database: my_db - user: my_user - password: my_pass ---- - kind: authServices - name: my-google-service - type: google - clientId: my-client-id ---- - kind: authServices - name: other-google-service - type: google - clientId: other-client-id ---- - kind: tools - name: example_tool - type: postgres-sql - source: my-pg-instance - description: some 
description - statement: | - SELECT * FROM SQL_STATEMENT; - authRequired: - - my-google-service - parameters: - - name: country - type: string - description: some description - - name: id - type: integer - description: user id - authServices: - - name: my-google-service - field: user_id - - name: email - type: string - description: user email - authServices: - - name: my-google-service - field: email - - name: other-google-service - field: other_email ---- - kind: toolsets - name: example_toolset - tools: - - example_tool - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-pg-instance": cloudsqlpgsrc.Config{ - Name: "my-pg-instance", - Type: cloudsqlpgsrc.SourceType, - Project: "my-project", - Region: "my-region", - Instance: "my-instance", - IPType: "public", - Database: "my_db", - User: "my_user", - Password: "my_pass", - }, - }, - AuthServices: server.AuthServiceConfigs{ - "my-google-service": google.Config{ - Name: "my-google-service", - Type: google.AuthServiceType, - ClientID: "my-client-id", - }, - "other-google-service": google.Config{ - Name: "other-google-service", - Type: google.AuthServiceType, - ClientID: "other-client-id", - }, - }, - Tools: server.ToolConfigs{ - "example_tool": postgressql.Config{ - Name: "example_tool", - Type: "postgres-sql", - Source: "my-pg-instance", - Description: "some description", - Statement: "SELECT * FROM SQL_STATEMENT;\n", - AuthRequired: []string{"my-google-service"}, - Parameters: []parameters.Parameter{ - parameters.NewStringParameter("country", "some description"), - parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}), - parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}), - }, - }, - }, - Toolsets: server.ToolsetConfigs{ - "example_toolset": tools.ToolsetConfig{ - Name: 
"example_toolset", - ToolNames: []string{"example_tool"}, - }, - }, - Prompts: nil, - }, - }, - } - for _, tc := range tcs { - t.Run(tc.description, func(t *testing.T) { - toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in)) - if err != nil { - t.Fatalf("failed to parse input: %v", err) - } - if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" { - t.Fatalf("incorrect sources parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" { - t.Fatalf("incorrect authServices parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" { - t.Fatalf("incorrect tools parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" { - t.Fatalf("incorrect toolsets parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" { - t.Fatalf("incorrect prompts parse: diff %v", diff) - } - }) - } - -} - -func TestEnvVarReplacement(t *testing.T) { - ctx, err := testutils.ContextWithNewLogger() - t.Setenv("TestHeader", "ACTUAL_HEADER") - t.Setenv("API_KEY", "ACTUAL_API_KEY") - t.Setenv("clientId", "ACTUAL_CLIENT_ID") - t.Setenv("clientId2", "ACTUAL_CLIENT_ID_2") - t.Setenv("toolset_name", "ACTUAL_TOOLSET_NAME") - t.Setenv("cat_string", "cat") - t.Setenv("food_string", "food") - t.Setenv("TestHeader", "ACTUAL_HEADER") - t.Setenv("prompt_name", "ACTUAL_PROMPT_NAME") - t.Setenv("prompt_content", "ACTUAL_CONTENT") - - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - tcs := []struct { - description string - in string - wantToolsFile ToolsFile - }{ - { - description: "file with env var example", - in: ` - sources: - my-http-instance: - kind: http - baseUrl: http://test_server/ - timeout: 10s - headers: - Authorization: ${TestHeader} - queryParams: - api-key: ${API_KEY} - authServices: - my-google-service: - kind: google - clientId: ${clientId} - 
other-google-service: - kind: google - clientId: ${clientId2} - - tools: - example_tool: - kind: http - source: my-instance - method: GET - path: "search?name=alice&pet=${cat_string}" - description: some description - authRequired: - - my-google-auth-service - - other-auth-service - queryParams: - - name: country - type: string - description: some description - authServices: - - name: my-google-auth-service - field: user_id - - name: other-auth-service - field: user_id - requestBody: | - { - "age": {{.age}}, - "city": "{{.city}}", - "food": "${food_string}", - "other": "$OTHER" - } - bodyParams: - - name: age - type: integer - description: age num - - name: city - type: string - description: city string - headers: - Authorization: API_KEY - Content-Type: application/json - headerParams: - - name: Language - type: string - description: language string - - toolsets: - ${toolset_name}: - - example_tool - - - prompts: - ${prompt_name}: - description: A test prompt for {{.name}}. - messages: - - role: user - content: ${prompt_content} - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-http-instance": httpsrc.Config{ - Name: "my-http-instance", - Type: httpsrc.SourceType, - BaseURL: "http://test_server/", - Timeout: "10s", - DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"}, - QueryParams: map[string]string{"api-key": "ACTUAL_API_KEY"}, - }, - }, - AuthServices: server.AuthServiceConfigs{ - "my-google-service": google.Config{ - Name: "my-google-service", - Type: google.AuthServiceType, - ClientID: "ACTUAL_CLIENT_ID", - }, - "other-google-service": google.Config{ - Name: "other-google-service", - Type: google.AuthServiceType, - ClientID: "ACTUAL_CLIENT_ID_2", - }, - }, - Tools: server.ToolConfigs{ - "example_tool": http.Config{ - Name: "example_tool", - Type: "http", - Source: "my-instance", - Method: "GET", - Path: "search?name=alice&pet=cat", - Description: "some description", - AuthRequired: []string{"my-google-auth-service", 
"other-auth-service"}, - QueryParams: []parameters.Parameter{ - parameters.NewStringParameterWithAuth("country", "some description", - []parameters.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, - {Name: "other-auth-service", Field: "user_id"}}), - }, - RequestBody: `{ - "age": {{.age}}, - "city": "{{.city}}", - "food": "food", - "other": "$OTHER" -} -`, - BodyParams: []parameters.Parameter{parameters.NewIntParameter("age", "age num"), parameters.NewStringParameter("city", "city string")}, - Headers: map[string]string{"Authorization": "API_KEY", "Content-Type": "application/json"}, - HeaderParams: []parameters.Parameter{parameters.NewStringParameter("Language", "language string")}, - }, - }, - Toolsets: server.ToolsetConfigs{ - "ACTUAL_TOOLSET_NAME": tools.ToolsetConfig{ - Name: "ACTUAL_TOOLSET_NAME", - ToolNames: []string{"example_tool"}, - }, - }, - Prompts: server.PromptConfigs{ - "ACTUAL_PROMPT_NAME": &custom.Config{ - Name: "ACTUAL_PROMPT_NAME", - Description: "A test prompt for {{.name}}.", - Messages: []prompts.Message{ - { - Role: "user", - Content: "ACTUAL_CONTENT", - }, - }, - Arguments: nil, - }, - }, - }, - }, - { - description: "file with env var example toolsfile v2", - in: ` - kind: sources - name: my-http-instance - type: http - baseUrl: http://test_server/ - timeout: 10s - headers: - Authorization: ${TestHeader} - queryParams: - api-key: ${API_KEY} ---- - kind: authServices - name: my-google-service - type: google - clientId: ${clientId} ---- - kind: authServices - name: other-google-service - type: google - clientId: ${clientId2} ---- - kind: tools - name: example_tool - type: http - source: my-instance - method: GET - path: "search?name=alice&pet=${cat_string}" - description: some description - authRequired: - - my-google-auth-service - - other-auth-service - queryParams: - - name: country - type: string - description: some description - authServices: - - name: my-google-auth-service - field: user_id - - name: 
other-auth-service - field: user_id - requestBody: | - { - "age": {{.age}}, - "city": "{{.city}}", - "food": "${food_string}", - "other": "$OTHER" - } - bodyParams: - - name: age - type: integer - description: age num - - name: city - type: string - description: city string - headers: - Authorization: API_KEY - Content-Type: application/json - headerParams: - - name: Language - type: string - description: language string ---- - kind: toolsets - name: ${toolset_name} - tools: - - example_tool ---- - kind: prompts - name: ${prompt_name} - description: A test prompt for {{.name}}. - messages: - - role: user - content: ${prompt_content} - `, - wantToolsFile: ToolsFile{ - Sources: server.SourceConfigs{ - "my-http-instance": httpsrc.Config{ - Name: "my-http-instance", - Type: httpsrc.SourceType, - BaseURL: "http://test_server/", - Timeout: "10s", - DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"}, - QueryParams: map[string]string{"api-key": "ACTUAL_API_KEY"}, - }, - }, - AuthServices: server.AuthServiceConfigs{ - "my-google-service": google.Config{ - Name: "my-google-service", - Type: google.AuthServiceType, - ClientID: "ACTUAL_CLIENT_ID", - }, - "other-google-service": google.Config{ - Name: "other-google-service", - Type: google.AuthServiceType, - ClientID: "ACTUAL_CLIENT_ID_2", - }, - }, - Tools: server.ToolConfigs{ - "example_tool": http.Config{ - Name: "example_tool", - Type: "http", - Source: "my-instance", - Method: "GET", - Path: "search?name=alice&pet=cat", - Description: "some description", - AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, - QueryParams: []parameters.Parameter{ - parameters.NewStringParameterWithAuth("country", "some description", - []parameters.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, - {Name: "other-auth-service", Field: "user_id"}}), - }, - RequestBody: `{ - "age": {{.age}}, - "city": "{{.city}}", - "food": "food", - "other": "$OTHER" -} -`, - BodyParams: 
[]parameters.Parameter{parameters.NewIntParameter("age", "age num"), parameters.NewStringParameter("city", "city string")}, - Headers: map[string]string{"Authorization": "API_KEY", "Content-Type": "application/json"}, - HeaderParams: []parameters.Parameter{parameters.NewStringParameter("Language", "language string")}, - }, - }, - Toolsets: server.ToolsetConfigs{ - "ACTUAL_TOOLSET_NAME": tools.ToolsetConfig{ - Name: "ACTUAL_TOOLSET_NAME", - ToolNames: []string{"example_tool"}, - }, - }, - Prompts: server.PromptConfigs{ - "ACTUAL_PROMPT_NAME": &custom.Config{ - Name: "ACTUAL_PROMPT_NAME", - Description: "A test prompt for {{.name}}.", - Messages: []prompts.Message{ - { - Role: "user", - Content: "ACTUAL_CONTENT", - }, - }, - Arguments: nil, - }, - }, - }, - }, - } - for _, tc := range tcs { - t.Run(tc.description, func(t *testing.T) { - toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in)) - if err != nil { - t.Fatalf("failed to parse input: %v", err) - } - if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" { - t.Fatalf("incorrect sources parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" { - t.Fatalf("incorrect authServices parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" { - t.Fatalf("incorrect tools parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" { - t.Fatalf("incorrect toolsets parse: diff %v", diff) - } - if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" { - t.Fatalf("incorrect prompts parse: diff %v", diff) - } - }) - } -} - // normalizeFilepaths is a helper function to allow same filepath formats for Mac and Windows. 
// this prevents needing multiple "want" cases for TestResolveWatcherInputs func normalizeFilepaths(m map[string]bool) map[string]bool { @@ -2052,485 +618,6 @@ func TestSingleEdit(t *testing.T) { } } -func TestPrebuiltTools(t *testing.T) { - // Get prebuilt configs - alloydb_omni_config, _ := prebuiltconfigs.Get("alloydb-omni") - alloydb_admin_config, _ := prebuiltconfigs.Get("alloydb-postgres-admin") - alloydb_config, _ := prebuiltconfigs.Get("alloydb-postgres") - bigquery_config, _ := prebuiltconfigs.Get("bigquery") - clickhouse_config, _ := prebuiltconfigs.Get("clickhouse") - cloudsqlpg_config, _ := prebuiltconfigs.Get("cloud-sql-postgres") - cloudsqlpg_admin_config, _ := prebuiltconfigs.Get("cloud-sql-postgres-admin") - cloudsqlmysql_config, _ := prebuiltconfigs.Get("cloud-sql-mysql") - cloudsqlmysql_admin_config, _ := prebuiltconfigs.Get("cloud-sql-mysql-admin") - cloudsqlmssql_config, _ := prebuiltconfigs.Get("cloud-sql-mssql") - cloudsqlmssql_admin_config, _ := prebuiltconfigs.Get("cloud-sql-mssql-admin") - dataplex_config, _ := prebuiltconfigs.Get("dataplex") - firestoreconfig, _ := prebuiltconfigs.Get("firestore") - mysql_config, _ := prebuiltconfigs.Get("mysql") - mssql_config, _ := prebuiltconfigs.Get("mssql") - looker_config, _ := prebuiltconfigs.Get("looker") - lookerca_config, _ := prebuiltconfigs.Get("looker-conversational-analytics") - postgresconfig, _ := prebuiltconfigs.Get("postgres") - spanner_config, _ := prebuiltconfigs.Get("spanner") - spannerpg_config, _ := prebuiltconfigs.Get("spanner-postgres") - mindsdb_config, _ := prebuiltconfigs.Get("mindsdb") - sqlite_config, _ := prebuiltconfigs.Get("sqlite") - neo4jconfig, _ := prebuiltconfigs.Get("neo4j") - alloydbobsvconfig, _ := prebuiltconfigs.Get("alloydb-postgres-observability") - cloudsqlpgobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-postgres-observability") - cloudsqlmysqlobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-mysql-observability") - cloudsqlmssqlobsvconfig, _ := 
prebuiltconfigs.Get("cloud-sql-mssql-observability") - serverless_spark_config, _ := prebuiltconfigs.Get("serverless-spark") - cloudhealthcare_config, _ := prebuiltconfigs.Get("cloud-healthcare") - snowflake_config, _ := prebuiltconfigs.Get("snowflake") - - // Set environment variables - t.Setenv("API_KEY", "your_api_key") - - t.Setenv("BIGQUERY_PROJECT", "your_gcp_project_id") - t.Setenv("DATAPLEX_PROJECT", "your_gcp_project_id") - t.Setenv("FIRESTORE_PROJECT", "your_gcp_project_id") - t.Setenv("FIRESTORE_DATABASE", "your_firestore_db_name") - - t.Setenv("SPANNER_PROJECT", "your_gcp_project_id") - t.Setenv("SPANNER_INSTANCE", "your_spanner_instance") - t.Setenv("SPANNER_DATABASE", "your_spanner_db") - - t.Setenv("ALLOYDB_POSTGRES_PROJECT", "your_gcp_project_id") - t.Setenv("ALLOYDB_POSTGRES_REGION", "your_gcp_region") - t.Setenv("ALLOYDB_POSTGRES_CLUSTER", "your_alloydb_cluster") - t.Setenv("ALLOYDB_POSTGRES_INSTANCE", "your_alloydb_instance") - t.Setenv("ALLOYDB_POSTGRES_DATABASE", "your_alloydb_db") - t.Setenv("ALLOYDB_POSTGRES_USER", "your_alloydb_user") - t.Setenv("ALLOYDB_POSTGRES_PASSWORD", "your_alloydb_password") - - t.Setenv("ALLOYDB_OMNI_HOST", "localhost") - t.Setenv("ALLOYDB_OMNI_PORT", "5432") - t.Setenv("ALLOYDB_OMNI_DATABASE", "your_alloydb_db") - t.Setenv("ALLOYDB_OMNI_USER", "your_alloydb_user") - t.Setenv("ALLOYDB_OMNI_PASSWORD", "your_alloydb_password") - - t.Setenv("CLICKHOUSE_PROTOCOL", "your_clickhouse_protocol") - t.Setenv("CLICKHOUSE_DATABASE", "your_clickhouse_database") - t.Setenv("CLICKHOUSE_PASSWORD", "your_clickhouse_password") - t.Setenv("CLICKHOUSE_USER", "your_clickhouse_user") - t.Setenv("CLICKHOUSE_HOST", "your_clickhosue_host") - t.Setenv("CLICKHOUSE_PORT", "8123") - - t.Setenv("CLOUD_SQL_POSTGRES_PROJECT", "your_pg_project") - t.Setenv("CLOUD_SQL_POSTGRES_INSTANCE", "your_pg_instance") - t.Setenv("CLOUD_SQL_POSTGRES_DATABASE", "your_pg_db") - t.Setenv("CLOUD_SQL_POSTGRES_REGION", "your_pg_region") - 
t.Setenv("CLOUD_SQL_POSTGRES_USER", "your_pg_user") - t.Setenv("CLOUD_SQL_POSTGRES_PASS", "your_pg_pass") - - t.Setenv("CLOUD_SQL_MYSQL_PROJECT", "your_gcp_project_id") - t.Setenv("CLOUD_SQL_MYSQL_REGION", "your_gcp_region") - t.Setenv("CLOUD_SQL_MYSQL_INSTANCE", "your_instance") - t.Setenv("CLOUD_SQL_MYSQL_DATABASE", "your_cloudsql_mysql_db") - t.Setenv("CLOUD_SQL_MYSQL_USER", "your_cloudsql_mysql_user") - t.Setenv("CLOUD_SQL_MYSQL_PASSWORD", "your_cloudsql_mysql_password") - - t.Setenv("CLOUD_SQL_MSSQL_PROJECT", "your_gcp_project_id") - t.Setenv("CLOUD_SQL_MSSQL_REGION", "your_gcp_region") - t.Setenv("CLOUD_SQL_MSSQL_INSTANCE", "your_cloudsql_mssql_instance") - t.Setenv("CLOUD_SQL_MSSQL_DATABASE", "your_cloudsql_mssql_db") - t.Setenv("CLOUD_SQL_MSSQL_IP_ADDRESS", "127.0.0.1") - t.Setenv("CLOUD_SQL_MSSQL_USER", "your_cloudsql_mssql_user") - t.Setenv("CLOUD_SQL_MSSQL_PASSWORD", "your_cloudsql_mssql_password") - t.Setenv("CLOUD_SQL_POSTGRES_PASSWORD", "your_cloudsql_pg_password") - - t.Setenv("SERVERLESS_SPARK_PROJECT", "your_gcp_project_id") - t.Setenv("SERVERLESS_SPARK_LOCATION", "your_gcp_location") - - t.Setenv("POSTGRES_HOST", "localhost") - t.Setenv("POSTGRES_PORT", "5432") - t.Setenv("POSTGRES_DATABASE", "your_postgres_db") - t.Setenv("POSTGRES_USER", "your_postgres_user") - t.Setenv("POSTGRES_PASSWORD", "your_postgres_password") - - t.Setenv("MYSQL_HOST", "localhost") - t.Setenv("MYSQL_PORT", "3306") - t.Setenv("MYSQL_DATABASE", "your_mysql_db") - t.Setenv("MYSQL_USER", "your_mysql_user") - t.Setenv("MYSQL_PASSWORD", "your_mysql_password") - - t.Setenv("MSSQL_HOST", "localhost") - t.Setenv("MSSQL_PORT", "1433") - t.Setenv("MSSQL_DATABASE", "your_mssql_db") - t.Setenv("MSSQL_USER", "your_mssql_user") - t.Setenv("MSSQL_PASSWORD", "your_mssql_password") - - t.Setenv("MINDSDB_HOST", "localhost") - t.Setenv("MINDSDB_PORT", "47334") - t.Setenv("MINDSDB_DATABASE", "your_mindsdb_db") - t.Setenv("MINDSDB_USER", "your_mindsdb_user") - t.Setenv("MINDSDB_PASS", 
"your_mindsdb_password") - - t.Setenv("LOOKER_BASE_URL", "https://your_company.looker.com") - t.Setenv("LOOKER_CLIENT_ID", "your_looker_client_id") - t.Setenv("LOOKER_CLIENT_SECRET", "your_looker_client_secret") - t.Setenv("LOOKER_VERIFY_SSL", "true") - - t.Setenv("LOOKER_PROJECT", "your_project_id") - t.Setenv("LOOKER_LOCATION", "us") - - t.Setenv("SQLITE_DATABASE", "test.db") - - t.Setenv("NEO4J_URI", "bolt://localhost:7687") - t.Setenv("NEO4J_DATABASE", "neo4j") - t.Setenv("NEO4J_USERNAME", "your_neo4j_user") - t.Setenv("NEO4J_PASSWORD", "your_neo4j_password") - - t.Setenv("CLOUD_HEALTHCARE_PROJECT", "your_gcp_project_id") - t.Setenv("CLOUD_HEALTHCARE_REGION", "your_gcp_region") - t.Setenv("CLOUD_HEALTHCARE_DATASET", "your_healthcare_dataset") - - t.Setenv("SNOWFLAKE_ACCOUNT", "your_account") - t.Setenv("SNOWFLAKE_USER", "your_username") - t.Setenv("SNOWFLAKE_PASSWORD", "your_pass") - t.Setenv("SNOWFLAKE_DATABASE", "your_db") - t.Setenv("SNOWFLAKE_SCHEMA", "your_schema") - t.Setenv("SNOWFLAKE_WAREHOUSE", "your_wh") - t.Setenv("SNOWFLAKE_ROLE", "your_role") - - ctx, err := testutils.ContextWithNewLogger() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - tcs := []struct { - name string - in []byte - wantToolset server.ToolsetConfigs - }{ - { - name: "alloydb omni prebuilt tools", - in: alloydb_omni_config, - wantToolset: server.ToolsetConfigs{ - "alloydb_omni_database_tools": tools.ToolsetConfig{ - Name: "alloydb_omni_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_columnar_configurations", "list_columnar_recommended_columns", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", 
"replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, - }, - }, - }, - { - name: "alloydb postgres admin prebuilt tools", - in: alloydb_admin_config, - wantToolset: server.ToolsetConfigs{ - "alloydb_postgres_admin_tools": tools.ToolsetConfig{ - Name: "alloydb_postgres_admin_tools", - ToolNames: []string{"create_cluster", "wait_for_operation", "create_instance", "list_clusters", "list_instances", "list_users", "create_user", "get_cluster", "get_instance", "get_user"}, - }, - }, - }, - { - name: "cloudsql pg admin prebuilt tools", - in: cloudsqlpg_admin_config, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_postgres_admin_tools": tools.ToolsetConfig{ - Name: "cloud_sql_postgres_admin_tools", - ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup", "restore_backup"}, - }, - }, - }, - { - name: "cloudsql mysql admin prebuilt tools", - in: cloudsqlmysql_admin_config, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_mysql_admin_tools": tools.ToolsetConfig{ - Name: "cloud_sql_mysql_admin_tools", - ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"}, - }, - }, - }, - { - name: "cloudsql mssql admin prebuilt tools", - in: cloudsqlmssql_admin_config, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_mssql_admin_tools": tools.ToolsetConfig{ - Name: "cloud_sql_mssql_admin_tools", - ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"}, - }, - }, - }, - { - name: 
"alloydb prebuilt tools", - in: alloydb_config, - wantToolset: server.ToolsetConfigs{ - "alloydb_postgres_database_tools": tools.ToolsetConfig{ - Name: "alloydb_postgres_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, - }, - }, - }, - { - name: "bigquery prebuilt tools", - in: bigquery_config, - wantToolset: server.ToolsetConfigs{ - "bigquery_database_tools": tools.ToolsetConfig{ - Name: "bigquery_database_tools", - ToolNames: []string{"analyze_contribution", "ask_data_insights", "execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids", "search_catalog"}, - }, - }, - }, - { - name: "clickhouse prebuilt tools", - in: clickhouse_config, - wantToolset: server.ToolsetConfigs{ - "clickhouse_database_tools": tools.ToolsetConfig{ - Name: "clickhouse_database_tools", - ToolNames: []string{"execute_sql", "list_databases", "list_tables"}, - }, - }, - }, - { - name: "cloudsqlpg prebuilt tools", - in: cloudsqlpg_config, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_postgres_database_tools": tools.ToolsetConfig{ - Name: "cloud_sql_postgres_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", 
"list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, - }, - }, - }, - { - name: "cloudsqlmysql prebuilt tools", - in: cloudsqlmysql_config, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_mysql_database_tools": tools.ToolsetConfig{ - Name: "cloud_sql_mysql_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "get_query_plan", "list_active_queries", "list_tables_missing_unique_indexes", "list_table_fragmentation"}, - }, - }, - }, - { - name: "cloudsqlmssql prebuilt tools", - in: cloudsqlmssql_config, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_mssql_database_tools": tools.ToolsetConfig{ - Name: "cloud_sql_mssql_database_tools", - ToolNames: []string{"execute_sql", "list_tables"}, - }, - }, - }, - { - name: "dataplex prebuilt tools", - in: dataplex_config, - wantToolset: server.ToolsetConfigs{ - "dataplex_tools": tools.ToolsetConfig{ - Name: "dataplex_tools", - ToolNames: []string{"search_entries", "lookup_entry", "search_aspect_types"}, - }, - }, - }, - { - name: "serverless spark prebuilt tools", - in: serverless_spark_config, - wantToolset: server.ToolsetConfigs{ - "serverless_spark_tools": tools.ToolsetConfig{ - Name: "serverless_spark_tools", - ToolNames: []string{"list_batches", "get_batch", "cancel_batch", "create_pyspark_batch", "create_spark_batch"}, - }, - }, - }, - { - name: "firestore prebuilt tools", - in: firestoreconfig, - wantToolset: server.ToolsetConfigs{ - "firestore_database_tools": tools.ToolsetConfig{ - Name: "firestore_database_tools", - ToolNames: []string{"get_documents", "add_documents", "update_document", "list_collections", "delete_documents", "query_collection", 
"get_rules", "validate_rules"}, - }, - }, - }, - { - name: "mysql prebuilt tools", - in: mysql_config, - wantToolset: server.ToolsetConfigs{ - "mysql_database_tools": tools.ToolsetConfig{ - Name: "mysql_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "get_query_plan", "list_active_queries", "list_tables_missing_unique_indexes", "list_table_fragmentation"}, - }, - }, - }, - { - name: "mssql prebuilt tools", - in: mssql_config, - wantToolset: server.ToolsetConfigs{ - "mssql_database_tools": tools.ToolsetConfig{ - Name: "mssql_database_tools", - ToolNames: []string{"execute_sql", "list_tables"}, - }, - }, - }, - { - name: "looker prebuilt tools", - in: looker_config, - wantToolset: server.ToolsetConfigs{ - "looker_tools": tools.ToolsetConfig{ - Name: "looker_tools", - ToolNames: []string{"get_models", "get_explores", "get_dimensions", "get_measures", "get_filters", "get_parameters", "query", "query_sql", "query_url", "get_looks", "run_look", "make_look", "get_dashboards", "run_dashboard", "make_dashboard", "add_dashboard_element", "add_dashboard_filter", "generate_embed_url", "health_pulse", "health_analyze", "health_vacuum", "dev_mode", "get_projects", "get_project_files", "get_project_file", "create_project_file", "update_project_file", "delete_project_file", "validate_project", "get_connections", "get_connection_schemas", "get_connection_databases", "get_connection_tables", "get_connection_table_columns"}, - }, - }, - }, - { - name: "looker-conversational-analytics prebuilt tools", - in: lookerca_config, - wantToolset: server.ToolsetConfigs{ - "looker_conversational_analytics_tools": tools.ToolsetConfig{ - Name: "looker_conversational_analytics_tools", - ToolNames: []string{"ask_data_insights", "get_models", "get_explores"}, - }, - }, - }, - { - name: "postgres prebuilt tools", - in: postgresconfig, - wantToolset: server.ToolsetConfigs{ - "postgres_database_tools": tools.ToolsetConfig{ - Name: "postgres_database_tools", - ToolNames: 
[]string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, - }, - }, - }, - { - name: "spanner prebuilt tools", - in: spanner_config, - wantToolset: server.ToolsetConfigs{ - "spanner-database-tools": tools.ToolsetConfig{ - Name: "spanner-database-tools", - ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables", "list_graphs"}, - }, - }, - }, - { - name: "spanner pg prebuilt tools", - in: spannerpg_config, - wantToolset: server.ToolsetConfigs{ - "spanner_postgres_database_tools": tools.ToolsetConfig{ - Name: "spanner_postgres_database_tools", - ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables"}, - }, - }, - }, - { - name: "mindsdb prebuilt tools", - in: mindsdb_config, - wantToolset: server.ToolsetConfigs{ - "mindsdb-tools": tools.ToolsetConfig{ - Name: "mindsdb-tools", - ToolNames: []string{"mindsdb-execute-sql", "mindsdb-sql"}, - }, - }, - }, - { - name: "sqlite prebuilt tools", - in: sqlite_config, - wantToolset: server.ToolsetConfigs{ - "sqlite_database_tools": tools.ToolsetConfig{ - Name: "sqlite_database_tools", - ToolNames: []string{"execute_sql", "list_tables"}, - }, - }, - }, - { - name: "neo4j prebuilt tools", - in: neo4jconfig, - wantToolset: server.ToolsetConfigs{ - "neo4j_database_tools": tools.ToolsetConfig{ - Name: "neo4j_database_tools", - ToolNames: []string{"execute_cypher", "get_schema"}, - }, - }, - }, - { - name: "alloydb postgres observability 
prebuilt tools", - in: alloydbobsvconfig, - wantToolset: server.ToolsetConfigs{ - "alloydb_postgres_cloud_monitoring_tools": tools.ToolsetConfig{ - Name: "alloydb_postgres_cloud_monitoring_tools", - ToolNames: []string{"get_system_metrics", "get_query_metrics"}, - }, - }, - }, - { - name: "cloudsql postgres observability prebuilt tools", - in: cloudsqlpgobsvconfig, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_postgres_cloud_monitoring_tools": tools.ToolsetConfig{ - Name: "cloud_sql_postgres_cloud_monitoring_tools", - ToolNames: []string{"get_system_metrics", "get_query_metrics"}, - }, - }, - }, - { - name: "cloudsql mysql observability prebuilt tools", - in: cloudsqlmysqlobsvconfig, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_mysql_cloud_monitoring_tools": tools.ToolsetConfig{ - Name: "cloud_sql_mysql_cloud_monitoring_tools", - ToolNames: []string{"get_system_metrics", "get_query_metrics"}, - }, - }, - }, - { - name: "cloudsql mssql observability prebuilt tools", - in: cloudsqlmssqlobsvconfig, - wantToolset: server.ToolsetConfigs{ - "cloud_sql_mssql_cloud_monitoring_tools": tools.ToolsetConfig{ - Name: "cloud_sql_mssql_cloud_monitoring_tools", - ToolNames: []string{"get_system_metrics"}, - }, - }, - }, - { - name: "cloud healthcare prebuilt tools", - in: cloudhealthcare_config, - wantToolset: server.ToolsetConfigs{ - "cloud_healthcare_dataset_tools": tools.ToolsetConfig{ - Name: "cloud_healthcare_dataset_tools", - ToolNames: []string{"get_dataset", "list_dicom_stores", "list_fhir_stores"}, - }, - "cloud_healthcare_fhir_tools": tools.ToolsetConfig{ - Name: "cloud_healthcare_fhir_tools", - ToolNames: []string{"get_fhir_store", "get_fhir_store_metrics", "get_fhir_resource", "fhir_patient_search", "fhir_patient_everything", "fhir_fetch_page"}, - }, - "cloud_healthcare_dicom_tools": tools.ToolsetConfig{ - Name: "cloud_healthcare_dicom_tools", - ToolNames: []string{"get_dicom_store", "get_dicom_store_metrics", "search_dicom_studies", "search_dicom_series", 
"search_dicom_instances", "retrieve_rendered_dicom_instance"}, - }, - }, - }, - { - name: "Snowflake prebuilt tool", - in: snowflake_config, - wantToolset: server.ToolsetConfigs{ - "snowflake_tools": tools.ToolsetConfig{ - Name: "snowflake_tools", - ToolNames: []string{"execute_sql", "list_tables"}, - }, - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - toolsFile, err := parseToolsFile(ctx, tc.in) - if err != nil { - t.Fatalf("failed to parse input: %v", err) - } - if diff := cmp.Diff(tc.wantToolset, toolsFile.Toolsets); diff != "" { - t.Fatalf("incorrect tools parse: diff %v", diff) - } - // Prebuilt configs do not have prompts, so assert empty maps. - if len(toolsFile.Prompts) != 0 { - t.Fatalf("expected empty prompts map for prebuilt config, got: %v", toolsFile.Prompts) - } - }) - } -} - func TestMutuallyExclusiveFlags(t *testing.T) { testCases := []struct { desc string @@ -2551,7 +638,9 @@ func TestMutuallyExclusiveFlags(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - cmd := NewCommand() + buf := new(bytes.Buffer) + opts := cli.NewToolboxOptions(cli.WithIOStreams(buf, buf)) + cmd := NewCommand(opts) cmd.SetArgs(tc.args) err := cmd.Execute() if err == nil { @@ -2566,7 +655,9 @@ func TestMutuallyExclusiveFlags(t *testing.T) { func TestFileLoadingErrors(t *testing.T) { t.Run("non-existent tools-file", func(t *testing.T) { - cmd := NewCommand() + buf := new(bytes.Buffer) + opts := cli.NewToolboxOptions(cli.WithIOStreams(buf, buf)) + cmd := NewCommand(opts) // Use a file that is guaranteed not to exist nonExistentFile := filepath.Join(t.TempDir(), "non-existent-tools.yaml") cmd.SetArgs([]string{"--tools-file", nonExistentFile}) @@ -2581,7 +672,9 @@ func TestFileLoadingErrors(t *testing.T) { }) t.Run("non-existent tools-folder", func(t *testing.T) { - cmd := NewCommand() + buf := new(bytes.Buffer) + opts := cli.NewToolboxOptions(cli.WithIOStreams(buf, buf)) + cmd := NewCommand(opts) 
nonExistentFolder := filepath.Join(t.TempDir(), "non-existent-folder") cmd.SetArgs([]string{"--tools-folder", nonExistentFolder}) @@ -2595,94 +688,6 @@ func TestFileLoadingErrors(t *testing.T) { }) } -func TestMergeToolsFiles(t *testing.T) { - file1 := ToolsFile{ - Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, - Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}}, - Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}}, - EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, - } - file2 := ToolsFile{ - AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}}, - Tools: server.ToolConfigs{"tool2": http.Config{Name: "tool2"}}, - Toolsets: server.ToolsetConfigs{"set2": tools.ToolsetConfig{Name: "set2"}}, - } - fileWithConflicts := ToolsFile{ - Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, - Tools: server.ToolConfigs{"tool2": http.Config{Name: "tool2"}}, - } - - testCases := []struct { - name string - files []ToolsFile - want ToolsFile - wantErr bool - }{ - { - name: "merge two distinct files", - files: []ToolsFile{file1, file2}, - want: ToolsFile{ - Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, - AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}}, - Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}}, - Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}}, - Prompts: server.PromptConfigs{}, - EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, - }, - wantErr: false, - }, - { - name: "merge with conflicts", - files: []ToolsFile{file1, file2, fileWithConflicts}, - wantErr: true, - }, - { - name: "merge single file", - files: []ToolsFile{file1}, - want: ToolsFile{ - Sources: file1.Sources, - AuthServices: 
make(server.AuthServiceConfigs), - EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, - Tools: file1.Tools, - Toolsets: file1.Toolsets, - Prompts: server.PromptConfigs{}, - }, - }, - { - name: "merge empty list", - files: []ToolsFile{}, - want: ToolsFile{ - Sources: make(server.SourceConfigs), - AuthServices: make(server.AuthServiceConfigs), - EmbeddingModels: make(server.EmbeddingModelConfigs), - Tools: make(server.ToolConfigs), - Toolsets: make(server.ToolsetConfigs), - Prompts: server.PromptConfigs{}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got, err := mergeToolsFiles(tc.files...) - if (err != nil) != tc.wantErr { - t.Fatalf("mergeToolsFiles() error = %v, wantErr %v", err, tc.wantErr) - } - if !tc.wantErr { - if diff := cmp.Diff(tc.want, got); diff != "" { - t.Errorf("mergeToolsFiles() mismatch (-want +got):\n%s", diff) - } - } else { - if err == nil { - t.Fatal("expected an error for conflicting files but got none") - } - if !strings.Contains(err.Error(), "resource conflicts detected") { - t.Errorf("expected conflict error, but got: %v", err) - } - } - }) - } -} func TestPrebuiltAndCustomTools(t *testing.T) { t.Setenv("SQLITE_DATABASE", "test.db") // Setup custom tools file @@ -2848,7 +853,7 @@ authSources: ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - cmd, output, err := invokeCommandWithContext(ctx, tc.args) + _, opts, output, err := invokeCommandWithContext(ctx, tc.args) if tc.wantErr { if err == nil { @@ -2865,7 +870,7 @@ authSources: t.Errorf("server did not start successfully (no ready message found). 
Output:\n%s", output) } if tc.cfgCheck != nil { - if err := tc.cfgCheck(cmd.cfg); err != nil { + if err := tc.cfgCheck(opts.Cfg); err != nil { t.Errorf("config check failed: %v", err) } } @@ -2899,7 +904,7 @@ func TestDefaultToolsFileBehavior(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - _, output, err := invokeCommandWithContext(ctx, tc.args) + _, _, output, err := invokeCommandWithContext(ctx, tc.args) if tc.expectRun { if err != nil && err != context.DeadlineExceeded && err != context.Canceled { @@ -2921,114 +926,29 @@ func TestDefaultToolsFileBehavior(t *testing.T) { } } -func TestParameterReferenceValidation(t *testing.T) { - ctx, err := testutils.ContextWithNewLogger() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } +func TestSubcommandWiring(t *testing.T) { + buf := new(bytes.Buffer) + opts := cli.NewToolboxOptions(cli.WithIOStreams(buf, buf)) + baseCmd := NewCommand(opts) - // Base template - baseYaml := ` -sources: - dummy-source: - kind: http - baseUrl: http://example.com -tools: - test-tool: - kind: postgres-sql - source: dummy-source - description: test tool - statement: SELECT 1; - parameters: -%s` - - tcs := []struct { - desc string - params string - wantErr bool - errSubstr string + tests := []struct { + args []string + expectedName string }{ - { - desc: "valid backward reference", - params: ` - - name: source_param - type: string - description: source - - name: copy_param - type: string - description: copy - valueFromParam: source_param`, - wantErr: false, - }, - { - desc: "valid forward reference (out of order)", - params: ` - - name: copy_param - type: string - description: copy - valueFromParam: source_param - - name: source_param - type: string - description: source`, - wantErr: false, - }, - { - desc: "invalid missing reference", - params: ` - - name: copy_param - type: string - description: copy - valueFromParam: 
non_existent_param`, - wantErr: true, - errSubstr: "references '\"non_existent_param\"' in the 'valueFromParam' field", - }, - { - desc: "invalid self reference", - params: ` - - name: myself - type: string - description: self - valueFromParam: myself`, - wantErr: true, - errSubstr: "parameter \"myself\" cannot copy value from itself", - }, - { - desc: "multiple valid references", - params: ` - - name: a - type: string - description: a - - name: b - type: string - description: b - valueFromParam: a - - name: c - type: string - description: c - valueFromParam: a`, - wantErr: false, - }, + {[]string{"invoke"}, "invoke"}, + {[]string{"skills-generate"}, "skills-generate"}, } - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - // Indent parameters to match YAML structure - yamlContent := fmt.Sprintf(baseYaml, tc.params) + for _, tc := range tests { + // Find returns the Command struct and the remaining args + cmd, _, err := baseCmd.Find(tc.args) - _, err := parseToolsFile(ctx, []byte(yamlContent)) + if err != nil { + t.Fatalf("Failed to find command %v: %v", tc.args, err) + } - if tc.wantErr { - if err == nil { - t.Fatal("expected error, got nil") - } - if !strings.Contains(err.Error(), tc.errSubstr) { - t.Errorf("error %q does not contain expected substring %q", err.Error(), tc.errSubstr) - } - } else { - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - } - }) + if cmd.Name() != tc.expectedName { + t.Errorf("Expected command name %q, got %q", tc.expectedName, cmd.Name()) + } } } diff --git a/internal/cli/invoke/command.go b/internal/cli/invoke/command.go index 22ab8e55d3..057e6cbd8c 100644 --- a/internal/cli/invoke/command.go +++ b/internal/cli/invoke/command.go @@ -18,37 +18,15 @@ import ( "context" "encoding/json" "fmt" - "io" - "github.com/googleapis/genai-toolbox/internal/log" + "github.com/googleapis/genai-toolbox/internal/cli" "github.com/googleapis/genai-toolbox/internal/server" 
"github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/spf13/cobra" ) -// RootCommand defines the interface for required by invoke subcommand. -// This allows subcommands to access shared resources and functionality without -// direct coupling to the root command's implementation. -type RootCommand interface { - // Config returns a copy of the current server configuration. - Config() server.ServerConfig - - // Out returns the writer used for standard output. - Out() io.Writer - - // LoadConfig loads and merges the configuration from files, folders, and prebuilts. - LoadConfig(ctx context.Context) error - - // Setup initializes the runtime environment, including logging and telemetry. - // It returns the updated context and a shutdown function to be called when finished. - Setup(ctx context.Context) (context.Context, func(context.Context) error, error) - - // Logger returns the logger instance. - Logger() log.Logger -} - -func NewCommand(rootCmd RootCommand) *cobra.Command { +func NewCommand(opts *cli.ToolboxOptions) *cobra.Command { cmd := &cobra.Command{ Use: "invoke [params]", Short: "Execute a tool directly", @@ -58,17 +36,17 @@ Example: toolbox invoke my-tool '{"param1": "value1"}'`, Args: cobra.MinimumNArgs(1), RunE: func(c *cobra.Command, args []string) error { - return runInvoke(c, args, rootCmd) + return runInvoke(c, args, opts) }, } return cmd } -func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { +func runInvoke(cmd *cobra.Command, args []string, opts *cli.ToolboxOptions) error { ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() - ctx, shutdown, err := rootCmd.Setup(ctx) + ctx, shutdown, err := opts.Setup(ctx) if err != nil { return err } @@ -76,16 +54,16 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { _ = shutdown(ctx) }() - // Load and merge tool configurations - if err := rootCmd.LoadConfig(ctx); 
err != nil { + _, err = opts.LoadConfig(ctx) + if err != nil { return err } // Initialize Resources - sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, rootCmd.Config()) + sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, opts.Cfg) if err != nil { errMsg := fmt.Errorf("failed to initialize resources: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -96,7 +74,7 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { tool, ok := resourceMgr.GetTool(toolName) if !ok { errMsg := fmt.Errorf("tool %q not found", toolName) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -109,7 +87,7 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { if paramsInput != "" { if err := json.Unmarshal([]byte(paramsInput), ¶ms); err != nil { errMsg := fmt.Errorf("params must be a valid JSON string: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } } @@ -117,14 +95,14 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { parsedParams, err := parameters.ParseParams(tool.GetParameters(), params, nil) if err != nil { errMsg := fmt.Errorf("invalid parameters: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } parsedParams, err = tool.EmbedParams(ctx, parsedParams, resourceMgr.GetEmbeddingModelMap()) if err != nil { errMsg := fmt.Errorf("error embedding parameters: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -132,19 +110,19 @@ func runInvoke(cmd *cobra.Command, 
args []string, rootCmd RootCommand) error { requiresAuth, err := tool.RequiresClientAuthorization(resourceMgr) if err != nil { errMsg := fmt.Errorf("failed to check auth requirements: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } if requiresAuth { errMsg := fmt.Errorf("client authorization is not supported") - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } result, err := tool.Invoke(ctx, resourceMgr, parsedParams, "") if err != nil { errMsg := fmt.Errorf("tool execution failed: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -152,10 +130,10 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error { output, err := json.MarshalIndent(result, "", " ") if err != nil { errMsg := fmt.Errorf("failed to marshal result: %w", err) - rootCmd.Logger().ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } - fmt.Fprintln(rootCmd.Out(), string(output)) + fmt.Fprintln(opts.IOStreams.Out, string(output)) return nil } diff --git a/cmd/invoke_tool_test.go b/internal/cli/invoke/command_test.go similarity index 80% rename from cmd/invoke_tool_test.go rename to internal/cli/invoke/command_test.go index 4fa47817ef..62c8cd4cb7 100644 --- a/cmd/invoke_tool_test.go +++ b/internal/cli/invoke/command_test.go @@ -12,16 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package cmd +package invoke import ( - "context" + "bytes" "os" "path/filepath" "strings" "testing" + + "github.com/googleapis/genai-toolbox/internal/cli" + _ "github.com/googleapis/genai-toolbox/internal/sources/bigquery" + _ "github.com/googleapis/genai-toolbox/internal/sources/sqlite" + _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql" + _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql" + "github.com/spf13/cobra" ) +func invokeCommand(args []string) (string, error) { + parentCmd := &cobra.Command{Use: "toolbox"} + + buf := new(bytes.Buffer) + opts := cli.NewToolboxOptions(cli.WithIOStreams(buf, buf)) + cli.PersistentFlags(parentCmd, opts) + + cmd := NewCommand(opts) + parentCmd.AddCommand(cmd) + parentCmd.SetArgs(args) + + err := parentCmd.Execute() + return buf.String(), err +} + func TestInvokeTool(t *testing.T) { // Create a temporary tools file tmpDir := t.TempDir() @@ -86,7 +108,7 @@ tools: for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - _, got, err := invokeCommandWithContext(context.Background(), tc.args) + got, err := invokeCommand(tc.args) if (err != nil) != tc.wantErr { t.Fatalf("got error %v, wantErr %v", err, tc.wantErr) } @@ -121,7 +143,7 @@ tools: } args := []string{"invoke", "bq-tool", "--tools-file", toolsFilePath} - _, _, err := invokeCommandWithContext(context.Background(), args) + _, err := invokeCommand(args) if err == nil { t.Fatal("expected error for tool requiring client auth, but got nil") } diff --git a/internal/cli/options.go b/internal/cli/options.go new file mode 100644 index 0000000000..b3e39b02e6 --- /dev/null +++ b/internal/cli/options.go @@ -0,0 +1,251 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cli + +import ( + "context" + "fmt" + "io" + "os" + "slices" + "strings" + + "github.com/googleapis/genai-toolbox/internal/log" + "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/telemetry" + "github.com/googleapis/genai-toolbox/internal/util" +) + +type IOStreams struct { + In io.Reader + Out io.Writer + ErrOut io.Writer +} + +// ToolboxOptions holds dependencies shared by all commands. +type ToolboxOptions struct { + IOStreams IOStreams + Logger log.Logger + Cfg server.ServerConfig + ToolsFile string + ToolsFiles []string + ToolsFolder string + PrebuiltConfigs []string +} + +// Option defines a function that modifies the ToolboxOptions struct. +type Option func(*ToolboxOptions) + +// NewToolboxOptions creates a new instance with defaults, then applies any +// provided options. +func NewToolboxOptions(opts ...Option) *ToolboxOptions { + o := &ToolboxOptions{ + IOStreams: IOStreams{ + In: os.Stdin, + Out: os.Stdout, + ErrOut: os.Stderr, + }, + } + + for _, opt := range opts { + opt(o) + } + return o +} + +// Apply allows you to update an EXISTING ToolboxOptions instance. +// This is useful for "late binding". +func (o *ToolboxOptions) Apply(opts ...Option) { + for _, opt := range opts { + opt(o) + } +} + +// WithIOStreams updates the IO streams. 
+func WithIOStreams(out, err io.Writer) Option { + return func(o *ToolboxOptions) { + o.IOStreams.Out = out + o.IOStreams.ErrOut = err + } +} + +// Setup create logger and telemetry instrumentations. +func (opts *ToolboxOptions) Setup(ctx context.Context) (context.Context, func(context.Context) error, error) { + // If stdio, set logger's out stream (usually DEBUG and INFO logs) to + // errStream + loggerOut := opts.IOStreams.Out + if opts.Cfg.Stdio { + loggerOut = opts.IOStreams.ErrOut + } + + // Handle logger separately from config + logger, err := log.NewLogger(opts.Cfg.LoggingFormat.String(), opts.Cfg.LogLevel.String(), loggerOut, opts.IOStreams.ErrOut) + if err != nil { + return ctx, nil, fmt.Errorf("unable to initialize logger: %w", err) + } + + ctx = util.WithLogger(ctx, logger) + opts.Logger = logger + + // Set up OpenTelemetry + otelShutdown, err := telemetry.SetupOTel(ctx, opts.Cfg.Version, opts.Cfg.TelemetryOTLP, opts.Cfg.TelemetryGCP, opts.Cfg.TelemetryServiceName) + if err != nil { + errMsg := fmt.Errorf("error setting up OpenTelemetry: %w", err) + logger.ErrorContext(ctx, errMsg.Error()) + return ctx, nil, errMsg + } + + shutdownFunc := func(ctx context.Context) error { + err := otelShutdown(ctx) + if err != nil { + errMsg := fmt.Errorf("error shutting down OpenTelemetry: %w", err) + logger.ErrorContext(ctx, errMsg.Error()) + return err + } + return nil + } + + instrumentation, err := telemetry.CreateTelemetryInstrumentation(opts.Cfg.Version) + if err != nil { + errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err) + logger.ErrorContext(ctx, errMsg.Error()) + return ctx, shutdownFunc, errMsg + } + + ctx = util.WithInstrumentation(ctx, instrumentation) + + return ctx, shutdownFunc, nil +} + +// LoadConfig checks and merge files that should be loaded into the server +func (opts *ToolboxOptions) LoadConfig(ctx context.Context) (bool, error) { + // Determine if Custom Files should be loaded + // Check for explicit custom flags + 
isCustomConfigured := opts.ToolsFile != "" || len(opts.ToolsFiles) > 0 || opts.ToolsFolder != "" + + // Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags) + useDefaultToolsFile := len(opts.PrebuiltConfigs) == 0 && !isCustomConfigured + + if useDefaultToolsFile { + opts.ToolsFile = "tools.yaml" + isCustomConfigured = true + } + + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return isCustomConfigured, err + } + + var allToolsFiles []ToolsFile + + // Load Prebuilt Configuration + + if len(opts.PrebuiltConfigs) > 0 { + slices.Sort(opts.PrebuiltConfigs) + sourcesList := strings.Join(opts.PrebuiltConfigs, ", ") + logMsg := fmt.Sprintf("Using prebuilt tool configurations for: %s", sourcesList) + logger.InfoContext(ctx, logMsg) + + for _, configName := range opts.PrebuiltConfigs { + buf, err := prebuiltconfigs.Get(configName) + if err != nil { + logger.ErrorContext(ctx, err.Error()) + return isCustomConfigured, err + } + + // Parse into ToolsFile struct + parsed, err := parseToolsFile(ctx, buf) + if err != nil { + errMsg := fmt.Errorf("unable to parse prebuilt tool configuration for '%s': %w", configName, err) + logger.ErrorContext(ctx, errMsg.Error()) + return isCustomConfigured, errMsg + } + allToolsFiles = append(allToolsFiles, parsed) + } + } + + // Load Custom Configurations + if isCustomConfigured { + // Enforce exclusivity among custom flags (tools-file vs tools-files vs tools-folder) + if (opts.ToolsFile != "" && len(opts.ToolsFiles) > 0) || + (opts.ToolsFile != "" && opts.ToolsFolder != "") || + (len(opts.ToolsFiles) > 0 && opts.ToolsFolder != "") { + errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously") + logger.ErrorContext(ctx, errMsg.Error()) + return isCustomConfigured, errMsg + } + + var customTools ToolsFile + var err error + + if len(opts.ToolsFiles) > 0 { + // Use tools-files + logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool 
configuration files", len(opts.ToolsFiles))) + customTools, err = LoadAndMergeToolsFiles(ctx, opts.ToolsFiles) + } else if opts.ToolsFolder != "" { + // Use tools-folder + logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", opts.ToolsFolder)) + customTools, err = LoadAndMergeToolsFolder(ctx, opts.ToolsFolder) + } else { + // Use single file (tools-file or default `tools.yaml`) + buf, readFileErr := os.ReadFile(opts.ToolsFile) + if readFileErr != nil { + errMsg := fmt.Errorf("unable to read tool file at %q: %w", opts.ToolsFile, readFileErr) + logger.ErrorContext(ctx, errMsg.Error()) + return isCustomConfigured, errMsg + } + customTools, err = parseToolsFile(ctx, buf) + if err != nil { + err = fmt.Errorf("unable to parse tool file at %q: %w", opts.ToolsFile, err) + } + } + + if err != nil { + logger.ErrorContext(ctx, err.Error()) + return isCustomConfigured, err + } + allToolsFiles = append(allToolsFiles, customTools) + } + + // Modify version string based on loaded configurations + if len(opts.PrebuiltConfigs) > 0 { + tag := "prebuilt" + if isCustomConfigured { + tag = "custom" + } + // prebuiltConfigs is already sorted above + for _, configName := range opts.PrebuiltConfigs { + opts.Cfg.Version += fmt.Sprintf("+%s.%s", tag, configName) + } + } + + // Merge Everything + // This will error if custom tools collide with prebuilt tools + finalToolsFile, err := mergeToolsFiles(allToolsFiles...) 
+ if err != nil { + logger.ErrorContext(ctx, err.Error()) + return isCustomConfigured, err + } + + opts.Cfg.SourceConfigs = finalToolsFile.Sources + opts.Cfg.AuthServiceConfigs = finalToolsFile.AuthServices + opts.Cfg.EmbeddingModelConfigs = finalToolsFile.EmbeddingModels + opts.Cfg.ToolConfigs = finalToolsFile.Tools + opts.Cfg.ToolsetConfigs = finalToolsFile.Toolsets + opts.Cfg.PromptConfigs = finalToolsFile.Prompts + + return isCustomConfigured, nil +} diff --git a/cmd/options_test.go b/internal/cli/options_test.go similarity index 62% rename from cmd/options_test.go rename to internal/cli/options_test.go index e0ab779b52..ae1fd525fb 100644 --- a/cmd/options_test.go +++ b/internal/cli/options_test.go @@ -12,57 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -package cmd +package cli import ( "errors" "io" "testing" - - "github.com/spf13/cobra" ) -func TestCommandOptions(t *testing.T) { +func TestToolboxOptions(t *testing.T) { w := io.Discard tcs := []struct { desc string - isValid func(*Command) error + isValid func(*ToolboxOptions) error option Option }{ { desc: "with logger", - isValid: func(c *Command) error { - if c.outStream != w || c.errStream != w { + isValid: func(o *ToolboxOptions) error { + if o.IOStreams.Out != w || o.IOStreams.ErrOut != w { return errors.New("loggers do not match") } return nil }, - option: WithStreams(w, w), + option: WithIOStreams(w, w), }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - got, err := invokeProxyWithOption(tc.option) - if err != nil { - t.Fatal(err) - } + got := NewToolboxOptions(tc.option) if err := tc.isValid(got); err != nil { t.Errorf("option did not initialize command correctly: %v", err) } }) } } - -func invokeProxyWithOption(o Option) (*Command, error) { - c := NewCommand(o) - // Keep the test output quiet - c.SilenceUsage = true - c.SilenceErrors = true - // Disable execute behavior - c.RunE = func(*cobra.Command, []string) 
error { - return nil - } - - err := c.Execute() - return c, err -} diff --git a/internal/cli/persistent_flags.go b/internal/cli/persistent_flags.go new file mode 100644 index 0000000000..a94eccae32 --- /dev/null +++ b/internal/cli/persistent_flags.go @@ -0,0 +1,46 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cli + +import ( + "fmt" + "strings" + + "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" + "github.com/spf13/cobra" +) + +// PersistentFlags sets up flags that are available for all commands and +// subcommands +// It is also used to set up persistent flags during subcommand unit tests +func PersistentFlags(parentCmd *cobra.Command, opts *ToolboxOptions) { + persistentFlags := parentCmd.PersistentFlags() + + persistentFlags.StringVar(&opts.ToolsFile, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.") + persistentFlags.StringSliceVar(&opts.ToolsFiles, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file, or --tools-folder.") + persistentFlags.StringVar(&opts.ToolsFolder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. 
Cannot be used with --tools-file, or --tools-files.") + persistentFlags.Var(&opts.Cfg.LogLevel, "log-level", "Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'.") + persistentFlags.Var(&opts.Cfg.LoggingFormat, "logging-format", "Specify logging format to use. Allowed: 'standard' or 'JSON'.") + persistentFlags.BoolVar(&opts.Cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.") + persistentFlags.StringVar(&opts.Cfg.TelemetryOTLP, "telemetry-otlp", "", "Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318')") + persistentFlags.StringVar(&opts.Cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.") + // Fetch prebuilt tools sources to customize the help description + prebuiltHelp := fmt.Sprintf( + "Use a prebuilt tool configuration by source type. Allowed: '%s'. Can be specified multiple times.", + strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"), + ) + persistentFlags.StringSliceVar(&opts.PrebuiltConfigs, "prebuilt", []string{}, prebuiltHelp) + persistentFlags.StringSliceVar(&opts.Cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.") +} diff --git a/internal/cli/skills/command.go b/internal/cli/skills/command.go index d8b2d286a9..bf2cdbb85c 100644 --- a/internal/cli/skills/command.go +++ b/internal/cli/skills/command.go @@ -22,7 +22,7 @@ import ( "path/filepath" "sort" - "github.com/googleapis/genai-toolbox/internal/log" + "github.com/googleapis/genai-toolbox/internal/cli" "github.com/googleapis/genai-toolbox/internal/server" "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/tools" @@ -30,28 +30,9 @@ import ( "github.com/spf13/cobra" ) -// RootCommand defines the interface for required by skills-generate subcommand. 
-// This allows subcommands to access shared resources and functionality without -// direct coupling to the root command's implementation. -type RootCommand interface { - // Config returns a copy of the current server configuration. - Config() server.ServerConfig - - // LoadConfig loads and merges the configuration from files, folders, and prebuilts. - LoadConfig(ctx context.Context) error - - // Setup initializes the runtime environment, including logging and telemetry. - // It returns the updated context and a shutdown function to be called when finished. - Setup(ctx context.Context) (context.Context, func(context.Context) error, error) - - // Logger returns the logger instance. - Logger() log.Logger -} - // Command is the command for generating skills. -type Command struct { +type command struct { *cobra.Command - rootCmd RootCommand name string description string toolset string @@ -59,15 +40,13 @@ type Command struct { } // NewCommand creates a new Command. -func NewCommand(rootCmd RootCommand) *cobra.Command { - cmd := &Command{ - rootCmd: rootCmd, - } +func NewCommand(opts *cli.ToolboxOptions) *cobra.Command { + cmd := &command{} cmd.Command = &cobra.Command{ Use: "skills-generate", Short: "Generate skills from tool configurations", RunE: func(c *cobra.Command, args []string) error { - return cmd.run(c) + return run(cmd, opts) }, } @@ -81,11 +60,11 @@ func NewCommand(rootCmd RootCommand) *cobra.Command { return cmd.Command } -func (c *Command) run(cmd *cobra.Command) error { +func run(cmd *command, opts *cli.ToolboxOptions) error { ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() - ctx, shutdown, err := c.rootCmd.Setup(ctx) + ctx, shutdown, err := opts.Setup(ctx) if err != nil { return err } @@ -93,39 +72,37 @@ func (c *Command) run(cmd *cobra.Command) error { _ = shutdown(ctx) }() - logger := c.rootCmd.Logger() - - // Load and merge tool configurations - if err := c.rootCmd.LoadConfig(ctx); err != nil { + _, err = opts.LoadConfig(ctx) + if err 
!= nil { return err } - if err := os.MkdirAll(c.outputDir, 0755); err != nil { + if err := os.MkdirAll(cmd.outputDir, 0755); err != nil { errMsg := fmt.Errorf("error creating output directory: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } - logger.InfoContext(ctx, fmt.Sprintf("Generating skill '%s'...", c.name)) + opts.Logger.InfoContext(ctx, fmt.Sprintf("Generating skill '%s'...", cmd.name)) // Initialize toolbox and collect tools - allTools, err := c.collectTools(ctx) + allTools, err := cmd.collectTools(ctx, opts) if err != nil { errMsg := fmt.Errorf("error collecting tools: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } if len(allTools) == 0 { - logger.InfoContext(ctx, "No tools found to generate.") + opts.Logger.InfoContext(ctx, "No tools found to generate.") return nil } // Generate the combined skill directory - skillPath := filepath.Join(c.outputDir, c.name) + skillPath := filepath.Join(cmd.outputDir, cmd.name) if err := os.MkdirAll(skillPath, 0755); err != nil { errMsg := fmt.Errorf("error creating skill directory: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -133,7 +110,7 @@ func (c *Command) run(cmd *cobra.Command) error { assetsPath := filepath.Join(skillPath, "assets") if err := os.MkdirAll(assetsPath, 0755); err != nil { errMsg := fmt.Errorf("error creating assets dir: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -141,7 +118,7 @@ func (c *Command) run(cmd *cobra.Command) error { scriptsPath := filepath.Join(skillPath, "scripts") if err := os.MkdirAll(scriptsPath, 0755); err != nil { errMsg := fmt.Errorf("error creating scripts dir: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -154,10 +131,10 
@@ func (c *Command) run(cmd *cobra.Command) error { for _, toolName := range toolNames { // Generate YAML config in asset directory - minimizedContent, err := generateToolConfigYAML(c.rootCmd.Config(), toolName) + minimizedContent, err := generateToolConfigYAML(opts.Cfg, toolName) if err != nil { errMsg := fmt.Errorf("error generating filtered config for %s: %w", toolName, err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } @@ -166,7 +143,7 @@ func (c *Command) run(cmd *cobra.Command) error { destPath := filepath.Join(assetsPath, specificToolsFileName) if err := os.WriteFile(destPath, minimizedContent, 0644); err != nil { errMsg := fmt.Errorf("error writing filtered config for %s: %w", toolName, err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } } @@ -175,40 +152,40 @@ func (c *Command) run(cmd *cobra.Command) error { scriptContent, err := generateScriptContent(toolName, specificToolsFileName) if err != nil { errMsg := fmt.Errorf("error generating script content for %s: %w", toolName, err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } scriptFilename := filepath.Join(scriptsPath, fmt.Sprintf("%s.js", toolName)) if err := os.WriteFile(scriptFilename, []byte(scriptContent), 0755); err != nil { errMsg := fmt.Errorf("error writing script %s: %w", scriptFilename, err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } } // Generate SKILL.md - skillContent, err := generateSkillMarkdown(c.name, c.description, allTools) + skillContent, err := generateSkillMarkdown(cmd.name, cmd.description, allTools) if err != nil { errMsg := fmt.Errorf("error generating SKILL.md content: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } skillMdPath := filepath.Join(skillPath, "SKILL.md") 
if err := os.WriteFile(skillMdPath, []byte(skillContent), 0644); err != nil { errMsg := fmt.Errorf("error writing SKILL.md: %w", err) - logger.ErrorContext(ctx, errMsg.Error()) + opts.Logger.ErrorContext(ctx, errMsg.Error()) return errMsg } - logger.InfoContext(ctx, fmt.Sprintf("Successfully generated skill '%s' with %d tools.", c.name, len(allTools))) + opts.Logger.InfoContext(ctx, fmt.Sprintf("Successfully generated skill '%s' with %d tools.", cmd.name, len(allTools))) return nil } -func (c *Command) collectTools(ctx context.Context) (map[string]tools.Tool, error) { +func (c *command) collectTools(ctx context.Context, opts *cli.ToolboxOptions) (map[string]tools.Tool, error) { // Initialize Resources - sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, c.rootCmd.Config()) + sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, opts.Cfg) if err != nil { return nil, fmt.Errorf("failed to initialize resources: %w", err) } diff --git a/cmd/skill_generate_test.go b/internal/cli/skills/command_test.go similarity index 87% rename from cmd/skill_generate_test.go rename to internal/cli/skills/command_test.go index 3b91dc590b..19edaf9393 100644 --- a/cmd/skill_generate_test.go +++ b/internal/cli/skills/command_test.go @@ -12,17 +12,36 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package cmd +package skills import ( - "context" + "bytes" "os" "path/filepath" "strings" "testing" - "time" + + "github.com/googleapis/genai-toolbox/internal/cli" + _ "github.com/googleapis/genai-toolbox/internal/sources/sqlite" + _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql" + "github.com/spf13/cobra" ) +func invokeCommand(args []string) (string, error) { + parentCmd := &cobra.Command{Use: "toolbox"} + + buf := new(bytes.Buffer) + opts := cli.NewToolboxOptions(cli.WithIOStreams(buf, buf)) + cli.PersistentFlags(parentCmd, opts) + + cmd := NewCommand(opts) + parentCmd.AddCommand(cmd) + parentCmd.SetArgs(args) + + err := parentCmd.Execute() + return buf.String(), err +} + func TestGenerateSkill(t *testing.T) { // Create a temporary directory for tests tmpDir := t.TempDir() @@ -55,10 +74,7 @@ tools: "--description", "hello tool", } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - _, got, err := invokeCommandWithContext(ctx, args) + got, err := invokeCommand(args) if err != nil { t.Fatalf("command failed: %v\nOutput: %s", err, got) } @@ -136,7 +152,7 @@ func TestGenerateSkill_NoConfig(t *testing.T) { "--description", "test", } - _, _, err := invokeCommandWithContext(context.Background(), args) + _, err := invokeCommand(args) if err == nil { t.Fatal("expected command to fail when no configuration is provided and tools.yaml is missing") } @@ -170,7 +186,7 @@ func TestGenerateSkill_MissingArguments(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, got, err := invokeCommandWithContext(context.Background(), tt.args) + got, err := invokeCommand(tt.args) if err == nil { t.Fatalf("expected command to fail due to missing arguments, but it succeeded\nOutput: %s", got) } diff --git a/internal/cli/tools_file.go b/internal/cli/tools_file.go new file mode 100644 index 0000000000..340a5d6642 --- /dev/null +++ b/internal/cli/tools_file.go @@ -0,0 +1,349 @@ +// Copyright 2026 
Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cli + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "slices" + "strings" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/server" +) + +type ToolsFile struct { + Sources server.SourceConfigs `yaml:"sources"` + AuthServices server.AuthServiceConfigs `yaml:"authServices"` + EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"` + Tools server.ToolConfigs `yaml:"tools"` + Toolsets server.ToolsetConfigs `yaml:"toolsets"` + Prompts server.PromptConfigs `yaml:"prompts"` +} + +// parseEnv replaces environment variables ${ENV_NAME} with their values. +// also support ${ENV_NAME:default_value}. +func parseEnv(input string) (string, error) { + re := regexp.MustCompile(`\$\{(\w+)(:([^}]*))?\}`) + + var err error + output := re.ReplaceAllStringFunc(input, func(match string) string { + parts := re.FindStringSubmatch(match) + + // extract the variable name + variableName := parts[1] + if value, found := os.LookupEnv(variableName); found { + return value + } + if len(parts) >= 4 && parts[2] != "" { + return parts[3] + } + err = fmt.Errorf("environment variable not found: %q", variableName) + return "" + }) + return output, err +} + +// parseToolsFile parses the provided yaml into appropriate configs. 
+func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) { + var toolsFile ToolsFile + // Replace environment variables if found + output, err := parseEnv(string(raw)) + if err != nil { + return toolsFile, fmt.Errorf("error parsing environment variables: %s", err) + } + raw = []byte(output) + + raw, err = convertToolsFile(raw) + if err != nil { + return toolsFile, fmt.Errorf("error converting tools file: %s", err) + } + + // Parse contents + toolsFile.Sources, toolsFile.AuthServices, toolsFile.EmbeddingModels, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts, err = server.UnmarshalResourceConfig(ctx, raw) + if err != nil { + return toolsFile, err + } + return toolsFile, nil +} + +func convertToolsFile(raw []byte) ([]byte, error) { + var input yaml.MapSlice + decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap()) + + // convert to tools file v2 + var buf bytes.Buffer + encoder := yaml.NewEncoder(&buf) + + v1keys := []string{"sources", "authSources", "authServices", "embeddingModels", "tools", "toolsets", "prompts"} + for { + if err := decoder.Decode(&input); err != nil { + if err == io.EOF { + break + } + return nil, err + } + for _, item := range input { + key, ok := item.Key.(string) + if !ok { + return nil, fmt.Errorf("unexpected non-string key in input: %v", item.Key) + } + // check if the key is config file v1's key + if slices.Contains(v1keys, key) { + // check if value conversion to yaml.MapSlice successfully + // fields such as "tools" in toolsets might pass the first check but + // fail to convert to MapSlice + if slice, ok := item.Value.(yaml.MapSlice); ok { + // Deprecated: convert authSources to authServices + if key == "authSources" { + key = "authServices" + } + transformed, err := transformDocs(key, slice) + if err != nil { + return nil, err + } + // encode per-doc + for _, doc := range transformed { + if err := encoder.Encode(doc); err != nil { + return nil, err + } + } + } else { + // invalid input will be 
ignored + // we don't want to throw error here since the config could + // be valid but with a different order such as: + // --- + // tools: + // - tool_a + // kind: toolsets + // --- + continue + } + } else { + // this doc is already v2, encode to buf + if err := encoder.Encode(input); err != nil { + return nil, err + } + break + } + } + } + return buf.Bytes(), nil +} + +// transformDocs transforms the configuration file from v1 format to v2 +// yaml.MapSlice will preserve the order in a map +func transformDocs(kind string, input yaml.MapSlice) ([]yaml.MapSlice, error) { + var transformed []yaml.MapSlice + for _, entry := range input { + entryName, ok := entry.Key.(string) + if !ok { + return nil, fmt.Errorf("unexpected non-string key for entry in '%s': %v", kind, entry.Key) + } + entryBody := ProcessValue(entry.Value, kind == "toolsets") + + currentTransformed := yaml.MapSlice{ + {Key: "kind", Value: kind}, + {Key: "name", Value: entryName}, + } + + // Merge the transformed body into our result + if bodySlice, ok := entryBody.(yaml.MapSlice); ok { + currentTransformed = append(currentTransformed, bodySlice...) + } else { + return nil, fmt.Errorf("unable to convert entryBody to MapSlice") + } + transformed = append(transformed, currentTransformed) + } + return transformed, nil +} + +// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type' +func ProcessValue(v any, isToolset bool) any { + switch val := v.(type) { + case yaml.MapSlice: + // creating a new MapSlice is safer for recursive transformation + newVal := make(yaml.MapSlice, len(val)) + for i, item := range val { + // Perform renaming + if item.Key == "kind" { + item.Key = "type" + } + // Recursive call for nested values (e.g., nested objects or lists) + item.Value = ProcessValue(item.Value, false) + newVal[i] = item + } + return newVal + case []any: + // Process lists: If it's a toolset top-level list, wrap it. 
+ if isToolset { + return yaml.MapSlice{{Key: "tools", Value: val}} + } + // Otherwise, recurse into list items (to catch nested objects) + newVal := make([]any, len(val)) + for i := range val { + newVal[i] = ProcessValue(val[i], false) + } + return newVal + default: + return val + } +} + +// mergeToolsFiles merges multiple ToolsFile structs into one. +// Detects and raises errors for resource conflicts in sources, authServices, tools, and toolsets. +// All resource names (sources, authServices, tools, toolsets) must be unique across all files. +func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) { + merged := ToolsFile{ + Sources: make(server.SourceConfigs), + AuthServices: make(server.AuthServiceConfigs), + EmbeddingModels: make(server.EmbeddingModelConfigs), + Tools: make(server.ToolConfigs), + Toolsets: make(server.ToolsetConfigs), + Prompts: make(server.PromptConfigs), + } + + var conflicts []string + + for fileIndex, file := range files { + // Check for conflicts and merge sources + for name, source := range file.Sources { + if _, exists := merged.Sources[name]; exists { + conflicts = append(conflicts, fmt.Sprintf("source '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.Sources[name] = source + } + } + + // Check for conflicts and merge authServices + for name, authService := range file.AuthServices { + if _, exists := merged.AuthServices[name]; exists { + conflicts = append(conflicts, fmt.Sprintf("authService '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.AuthServices[name] = authService + } + } + + // Check for conflicts and merge embeddingModels + for name, em := range file.EmbeddingModels { + if _, exists := merged.EmbeddingModels[name]; exists { + conflicts = append(conflicts, fmt.Sprintf("embedding model '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.EmbeddingModels[name] = em + } + } + + // Check for conflicts and merge tools + for name, tool := range file.Tools { + if _, exists := merged.Tools[name]; exists { + 
conflicts = append(conflicts, fmt.Sprintf("tool '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.Tools[name] = tool + } + } + + // Check for conflicts and merge toolsets + for name, toolset := range file.Toolsets { + if _, exists := merged.Toolsets[name]; exists { + conflicts = append(conflicts, fmt.Sprintf("toolset '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.Toolsets[name] = toolset + } + } + + // Check for conflicts and merge prompts + for name, prompt := range file.Prompts { + if _, exists := merged.Prompts[name]; exists { + conflicts = append(conflicts, fmt.Sprintf("prompt '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.Prompts[name] = prompt + } + } + } + + // If conflicts were detected, return an error + if len(conflicts) > 0 { + return ToolsFile{}, fmt.Errorf("resource conflicts detected:\n - %s\n\nPlease ensure each source, authService, tool, toolset and prompt has a unique name across all files", strings.Join(conflicts, "\n - ")) + } + + return merged, nil +} + +// LoadAndMergeToolsFiles loads multiple YAML files and merges them +func LoadAndMergeToolsFiles(ctx context.Context, filePaths []string) (ToolsFile, error) { + var toolsFiles []ToolsFile + + for _, filePath := range filePaths { + buf, err := os.ReadFile(filePath) + if err != nil { + return ToolsFile{}, fmt.Errorf("unable to read tool file at %q: %w", filePath, err) + } + + toolsFile, err := parseToolsFile(ctx, buf) + if err != nil { + return ToolsFile{}, fmt.Errorf("unable to parse tool file at %q: %w", filePath, err) + } + + toolsFiles = append(toolsFiles, toolsFile) + } + + mergedFile, err := mergeToolsFiles(toolsFiles...) 
+ if err != nil { + return ToolsFile{}, fmt.Errorf("unable to merge tools files: %w", err) + } + + return mergedFile, nil +} + +// LoadAndMergeToolsFolder loads all YAML files from a directory and merges them +func LoadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile, error) { + // Check if directory exists + info, err := os.Stat(folderPath) + if err != nil { + return ToolsFile{}, fmt.Errorf("unable to access tools folder at %q: %w", folderPath, err) + } + if !info.IsDir() { + return ToolsFile{}, fmt.Errorf("path %q is not a directory", folderPath) + } + + // Find all YAML files in the directory + pattern := filepath.Join(folderPath, "*.yaml") + yamlFiles, err := filepath.Glob(pattern) + if err != nil { + return ToolsFile{}, fmt.Errorf("error finding YAML files in %q: %w", folderPath, err) + } + + // Also find .yml files + ymlPattern := filepath.Join(folderPath, "*.yml") + ymlFiles, err := filepath.Glob(ymlPattern) + if err != nil { + return ToolsFile{}, fmt.Errorf("error finding YML files in %q: %w", folderPath, err) + } + + // Combine both file lists + allFiles := append(yamlFiles, ymlFiles...) + + if len(allFiles) == 0 { + return ToolsFile{}, fmt.Errorf("no YAML files found in directory %q", folderPath) + } + + // Use existing LoadAndMergeToolsFiles function + return LoadAndMergeToolsFiles(ctx, allFiles) +} diff --git a/internal/cli/tools_file_test.go b/internal/cli/tools_file_test.go new file mode 100644 index 0000000000..9562a3065a --- /dev/null +++ b/internal/cli/tools_file_test.go @@ -0,0 +1,2143 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cli + +import ( + "fmt" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/auth/google" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini" + "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" + "github.com/googleapis/genai-toolbox/internal/prompts" + "github.com/googleapis/genai-toolbox/internal/prompts/custom" + "github.com/googleapis/genai-toolbox/internal/server" + cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" + httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/http" + "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + + _ "github.com/googleapis/genai-toolbox/cmd/imports" +) + +func TestParseEnv(t *testing.T) { + tcs := []struct { + desc string + env map[string]string + in string + want string + err bool + errString string + }{ + { + desc: "without default without env", + in: "${FOO}", + want: "", + err: true, + errString: `environment variable not found: "FOO"`, + }, + { + desc: "without default with env", + env: map[string]string{ + "FOO": "bar", + }, + in: "${FOO}", + want: "bar", + }, + { + desc: "with empty default", + in: "${FOO:}", + want: "", + }, + { + desc: "with default", + in: "${FOO:bar}", + want: "bar", + }, + { + desc: "with 
default with env", + env: map[string]string{ + "FOO": "hello", + }, + in: "${FOO:bar}", + want: "hello", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + if tc.env != nil { + for k, v := range tc.env { + t.Setenv(k, v) + } + } + got, err := parseEnv(tc.in) + if tc.err { + if err == nil { + t.Fatalf("expected error not found") + } + if tc.errString != err.Error() { + t.Fatalf("incorrect error string: got %s, want %s", err, tc.errString) + } + } + if tc.want != got { + t.Fatalf("unexpected want: got %s, want %s", got, tc.want) + } + }) + } +} + +func TestConvertToolsFile(t *testing.T) { + tcs := []struct { + desc string + in string + want string + isErr bool + errStr string + }{ + { + desc: "basic convert", + in: ` + sources: + my-pg-instance: + kind: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + authServices: + my-google-auth: + kind: google + clientId: testing-id + tools: + example_tool: + kind: postgres-sql + source: my-pg-instance + description: some description + statement: SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description + toolsets: + example_toolset: + - example_tool + prompts: + code_review: + description: ask llm to analyze code quality + messages: + - content: "please review the following code for quality: {{.code}}" + arguments: + - name: code + description: the code to review + embeddingModels: + gemini-model: + kind: gemini + model: gemini-embedding-001 + apiKey: some-key + dimension: 768`, + want: `kind: sources +name: my-pg-instance +type: cloud-sql-postgres +project: my-project +region: my-region +instance: my-instance +database: my_db +user: my_user +password: my_pass +--- +kind: authServices +name: my-google-auth +type: google +clientId: testing-id +--- +kind: tools +name: example_tool +type: postgres-sql +source: my-pg-instance +description: some description +statement: 
SELECT * FROM SQL_STATEMENT; +parameters: +- name: country + type: string + description: some description +--- +kind: toolsets +name: example_toolset +tools: +- example_tool +--- +kind: prompts +name: code_review +description: ask llm to analyze code quality +messages: +- content: "please review the following code for quality: {{.code}}" +arguments: +- name: code + description: the code to review +--- +kind: embeddingModels +name: gemini-model +type: gemini +model: gemini-embedding-001 +apiKey: some-key +dimension: 768 +`, + }, + { + desc: "preserve resource order", + in: ` + tools: + example_tool: + kind: postgres-sql + source: my-pg-instance + description: some description + statement: SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description + sources: + my-pg-instance: + kind: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + authServices: + my-google-auth: + kind: google + clientId: testing-id + toolsets: + example_toolset: + - example_tool + authSources: + my-google-auth2: + kind: google + clientId: testing-id`, + want: `kind: tools +name: example_tool +type: postgres-sql +source: my-pg-instance +description: some description +statement: SELECT * FROM SQL_STATEMENT; +parameters: +- name: country + type: string + description: some description +--- +kind: sources +name: my-pg-instance +type: cloud-sql-postgres +project: my-project +region: my-region +instance: my-instance +database: my_db +user: my_user +password: my_pass +--- +kind: authServices +name: my-google-auth +type: google +clientId: testing-id +--- +kind: toolsets +name: example_toolset +tools: +- example_tool +--- +kind: authServices +name: my-google-auth2 +type: google +clientId: testing-id +`, + }, + { + desc: "convert combination of v1 and v2", + in: ` + sources: + my-pg-instance: + kind: cloud-sql-postgres + project: my-project + region: my-region + 
instance: my-instance + database: my_db + user: my_user + password: my_pass + authServices: + my-google-auth: + kind: google + clientId: testing-id + tools: + example_tool: + kind: postgres-sql + source: my-pg-instance + description: some description + statement: SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description + toolsets: + example_toolset: + - example_tool + prompts: + code_review: + description: ask llm to analyze code quality + messages: + - content: "please review the following code for quality: {{.code}}" + arguments: + - name: code + description: the code to review + embeddingModels: + gemini-model: + kind: gemini + model: gemini-embedding-001 + apiKey: some-key + dimension: 768 +--- + kind: sources + name: my-pg-instance2 + type: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance +--- + kind: authServices + name: my-google-auth2 + type: google + clientId: testing-id +--- + kind: tools + name: example_tool2 + type: postgres-sql + source: my-pg-instance + description: some description + statement: SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description +--- + kind: toolsets + name: example_toolset2 + tools: + - example_tool +--- + tools: + - example_tool + kind: toolsets + name: example_toolset3 +--- + kind: prompts + name: code_review2 + description: ask llm to analyze code quality + messages: + - content: "please review the following code for quality: {{.code}}" + arguments: + - name: code + description: the code to review +--- + kind: embeddingModels + name: gemini-model2 + type: gemini`, + want: `kind: sources +name: my-pg-instance +type: cloud-sql-postgres +project: my-project +region: my-region +instance: my-instance +database: my_db +user: my_user +password: my_pass +--- +kind: authServices +name: my-google-auth +type: google +clientId: testing-id +--- +kind: tools +name: example_tool +type: postgres-sql 
+source: my-pg-instance +description: some description +statement: SELECT * FROM SQL_STATEMENT; +parameters: +- name: country + type: string + description: some description +--- +kind: toolsets +name: example_toolset +tools: +- example_tool +--- +kind: prompts +name: code_review +description: ask llm to analyze code quality +messages: +- content: "please review the following code for quality: {{.code}}" +arguments: +- name: code + description: the code to review +--- +kind: embeddingModels +name: gemini-model +type: gemini +model: gemini-embedding-001 +apiKey: some-key +dimension: 768 +--- +kind: sources +name: my-pg-instance2 +type: cloud-sql-postgres +project: my-project +region: my-region +instance: my-instance +--- +kind: authServices +name: my-google-auth2 +type: google +clientId: testing-id +--- +kind: tools +name: example_tool2 +type: postgres-sql +source: my-pg-instance +description: some description +statement: SELECT * FROM SQL_STATEMENT; +parameters: +- name: country + type: string + description: some description +--- +kind: toolsets +name: example_toolset2 +tools: +- example_tool +--- +tools: +- example_tool +kind: toolsets +name: example_toolset3 +--- +kind: prompts +name: code_review2 +description: ask llm to analyze code quality +messages: +- content: "please review the following code for quality: {{.code}}" +arguments: +- name: code + description: the code to review +--- +kind: embeddingModels +name: gemini-model2 +type: gemini +`, + }, + { + desc: "no convertion needed", + in: `kind: sources +name: my-pg-instance +type: cloud-sql-postgres +project: my-project +region: my-region +instance: my-instance +database: my_db +user: my_user +password: my_pass +--- +kind: tools +name: example_tool +type: postgres-sql +source: my-pg-instance +description: some description +statement: SELECT * FROM SQL_STATEMENT; +parameters: +- name: country + type: string + description: some description +--- +kind: toolsets +name: example_toolset +tools: +- example_tool`, + 
want: `kind: sources +name: my-pg-instance +type: cloud-sql-postgres +project: my-project +region: my-region +instance: my-instance +database: my_db +user: my_user +password: my_pass +--- +kind: tools +name: example_tool +type: postgres-sql +source: my-pg-instance +description: some description +statement: SELECT * FROM SQL_STATEMENT; +parameters: +- name: country + type: string + description: some description +--- +kind: toolsets +name: example_toolset +tools: +- example_tool +`, + }, + { + desc: "invalid source", + in: `sources: invalid`, + want: "", + }, + { + desc: "invalid toolset", + in: `toolsets: invalid`, + want: "", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + output, err := convertToolsFile([]byte(tc.in)) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if diff := cmp.Diff(string(output), tc.want); diff != "" { + t.Fatalf("incorrect toolsets parse: diff %v", diff) + } + }) + } +} + +func TestParseToolFile(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + description string + in string + wantToolsFile ToolsFile + }{ + { + description: "basic example tools file v1", + in: ` + sources: + my-pg-instance: + kind: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + tools: + example_tool: + kind: postgres-sql + source: my-pg-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description + toolsets: + example_toolset: + - example_tool + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-pg-instance": cloudsqlpgsrc.Config{ + Name: "my-pg-instance", + Type: cloudsqlpgsrc.SourceType, + Project: "my-project", + Region: "my-region", + Instance: "my-instance", + IPType: "public", + Database: "my_db", + User: "my_user", 
+ Password: "my_pass", + }, + }, + Tools: server.ToolConfigs{ + "example_tool": postgressql.Config{ + Name: "example_tool", + Type: "postgres-sql", + Source: "my-pg-instance", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("country", "some description"), + }, + AuthRequired: []string{}, + }, + }, + Toolsets: server.ToolsetConfigs{ + "example_toolset": tools.ToolsetConfig{ + Name: "example_toolset", + ToolNames: []string{"example_tool"}, + }, + }, + AuthServices: nil, + Prompts: nil, + }, + }, + { + description: "basic example tools file v2", + in: ` + kind: sources + name: my-pg-instance + type: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass +--- + kind: authServices + name: my-google-auth + type: google + clientId: testing-id +--- + kind: embeddingModels + name: gemini-model + type: gemini + model: gemini-embedding-001 + apiKey: some-key + dimension: 768 +--- + kind: tools + name: example_tool + type: postgres-sql + source: my-pg-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description +--- + kind: toolsets + name: example_toolset + tools: + - example_tool +--- + kind: prompts + name: code_review + description: ask llm to analyze code quality + messages: + - content: "please review the following code for quality: {{.code}}" + arguments: + - name: code + description: the code to review + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-pg-instance": cloudsqlpgsrc.Config{ + Name: "my-pg-instance", + Type: cloudsqlpgsrc.SourceType, + Project: "my-project", + Region: "my-region", + Instance: "my-instance", + IPType: "public", + Database: "my_db", + User: "my_user", + Password: "my_pass", + }, + }, + AuthServices: server.AuthServiceConfigs{ + 
"my-google-auth": google.Config{ + Name: "my-google-auth", + Type: google.AuthServiceType, + ClientID: "testing-id", + }, + }, + EmbeddingModels: server.EmbeddingModelConfigs{ + "gemini-model": gemini.Config{ + Name: "gemini-model", + Type: gemini.EmbeddingModelType, + Model: "gemini-embedding-001", + ApiKey: "some-key", + Dimension: 768, + }, + }, + Tools: server.ToolConfigs{ + "example_tool": postgressql.Config{ + Name: "example_tool", + Type: "postgres-sql", + Source: "my-pg-instance", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("country", "some description"), + }, + AuthRequired: []string{}, + }, + }, + Toolsets: server.ToolsetConfigs{ + "example_toolset": tools.ToolsetConfig{ + Name: "example_toolset", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: server.PromptConfigs{ + "code_review": &custom.Config{ + Name: "code_review", + Description: "ask llm to analyze code quality", + Arguments: prompts.Arguments{ + {Parameter: parameters.NewStringParameter("code", "the code to review")}, + }, + Messages: []prompts.Message{ + {Role: "user", Content: "please review the following code for quality: {{.code}}"}, + }, + }, + }, + }, + }, + { + description: "only prompts", + in: ` + kind: prompts + name: my-prompt + description: A prompt template for data analysis. + arguments: + - name: country + description: The country to analyze. + messages: + - content: Analyze the data for {{.country}}. 
+ `, + wantToolsFile: ToolsFile{ + Sources: nil, + AuthServices: nil, + Tools: nil, + Toolsets: nil, + Prompts: server.PromptConfigs{ + "my-prompt": &custom.Config{ + Name: "my-prompt", + Description: "A prompt template for data analysis.", + Arguments: prompts.Arguments{ + {Parameter: parameters.NewStringParameter("country", "The country to analyze.")}, + }, + Messages: []prompts.Message{ + {Role: "user", Content: "Analyze the data for {{.country}}."}, + }, + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.description, func(t *testing.T) { + toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("failed to parse input: %v", err) + } + if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" { + t.Fatalf("incorrect sources parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" { + t.Fatalf("incorrect authServices parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" { + t.Fatalf("incorrect tools parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" { + t.Fatalf("incorrect toolsets parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" { + t.Fatalf("incorrect prompts parse: diff %v", diff) + } + }) + } +} + +func TestParseToolFileWithAuth(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + description string + in string + wantToolsFile ToolsFile + }{ + { + description: "basic example", + in: ` + kind: sources + name: my-pg-instance + type: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass +--- + kind: authServices + name: my-google-service + type: google + clientId: my-client-id +--- + kind: authServices 
+ name: other-google-service + type: google + clientId: other-client-id +--- + kind: tools + name: example_tool + type: postgres-sql + source: my-pg-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description + - name: id + type: integer + description: user id + authServices: + - name: my-google-service + field: user_id + - name: email + type: string + description: user email + authServices: + - name: my-google-service + field: email + - name: other-google-service + field: other_email +--- + kind: toolsets + name: example_toolset + tools: + - example_tool + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-pg-instance": cloudsqlpgsrc.Config{ + Name: "my-pg-instance", + Type: cloudsqlpgsrc.SourceType, + Project: "my-project", + Region: "my-region", + Instance: "my-instance", + IPType: "public", + Database: "my_db", + User: "my_user", + Password: "my_pass", + }, + }, + AuthServices: server.AuthServiceConfigs{ + "my-google-service": google.Config{ + Name: "my-google-service", + Type: google.AuthServiceType, + ClientID: "my-client-id", + }, + "other-google-service": google.Config{ + Name: "other-google-service", + Type: google.AuthServiceType, + ClientID: "other-client-id", + }, + }, + Tools: server.ToolConfigs{ + "example_tool": postgressql.Config{ + Name: "example_tool", + Type: "postgres-sql", + Source: "my-pg-instance", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + AuthRequired: []string{}, + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("country", "some description"), + parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}), + parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: 
"other_email"}}), + }, + }, + }, + Toolsets: server.ToolsetConfigs{ + "example_toolset": tools.ToolsetConfig{ + Name: "example_toolset", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: nil, + }, + }, + { + description: "basic example with authSources", + in: ` + sources: + my-pg-instance: + kind: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + authSources: + my-google-service: + kind: google + clientId: my-client-id + other-google-service: + kind: google + clientId: other-client-id + + tools: + example_tool: + kind: postgres-sql + source: my-pg-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description + - name: id + type: integer + description: user id + authSources: + - name: my-google-service + field: user_id + - name: email + type: string + description: user email + authSources: + - name: my-google-service + field: email + - name: other-google-service + field: other_email + + toolsets: + example_toolset: + - example_tool + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-pg-instance": cloudsqlpgsrc.Config{ + Name: "my-pg-instance", + Type: cloudsqlpgsrc.SourceType, + Project: "my-project", + Region: "my-region", + Instance: "my-instance", + IPType: "public", + Database: "my_db", + User: "my_user", + Password: "my_pass", + }, + }, + AuthServices: server.AuthServiceConfigs{ + "my-google-service": google.Config{ + Name: "my-google-service", + Type: google.AuthServiceType, + ClientID: "my-client-id", + }, + "other-google-service": google.Config{ + Name: "other-google-service", + Type: google.AuthServiceType, + ClientID: "other-client-id", + }, + }, + Tools: server.ToolConfigs{ + "example_tool": postgressql.Config{ + Name: "example_tool", + Type: "postgres-sql", + Source: "my-pg-instance", + Description: "some description", + Statement: 
"SELECT * FROM SQL_STATEMENT;\n", + AuthRequired: []string{}, + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("country", "some description"), + parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}), + parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}), + }, + }, + }, + Toolsets: server.ToolsetConfigs{ + "example_toolset": tools.ToolsetConfig{ + Name: "example_toolset", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: nil, + }, + }, + { + description: "basic example with authRequired", + in: ` + kind: sources + name: my-pg-instance + type: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass +--- + kind: authServices + name: my-google-service + type: google + clientId: my-client-id +--- + kind: authServices + name: other-google-service + type: google + clientId: other-client-id +--- + kind: tools + name: example_tool + type: postgres-sql + source: my-pg-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + authRequired: + - my-google-service + parameters: + - name: country + type: string + description: some description + - name: id + type: integer + description: user id + authServices: + - name: my-google-service + field: user_id + - name: email + type: string + description: user email + authServices: + - name: my-google-service + field: email + - name: other-google-service + field: other_email +--- + kind: toolsets + name: example_toolset + tools: + - example_tool + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-pg-instance": cloudsqlpgsrc.Config{ + Name: "my-pg-instance", + Type: cloudsqlpgsrc.SourceType, + Project: "my-project", + Region: "my-region", + Instance: "my-instance", 
+ IPType: "public", + Database: "my_db", + User: "my_user", + Password: "my_pass", + }, + }, + AuthServices: server.AuthServiceConfigs{ + "my-google-service": google.Config{ + Name: "my-google-service", + Type: google.AuthServiceType, + ClientID: "my-client-id", + }, + "other-google-service": google.Config{ + Name: "other-google-service", + Type: google.AuthServiceType, + ClientID: "other-client-id", + }, + }, + Tools: server.ToolConfigs{ + "example_tool": postgressql.Config{ + Name: "example_tool", + Type: "postgres-sql", + Source: "my-pg-instance", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + AuthRequired: []string{"my-google-service"}, + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("country", "some description"), + parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}), + parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}), + }, + }, + }, + Toolsets: server.ToolsetConfigs{ + "example_toolset": tools.ToolsetConfig{ + Name: "example_toolset", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: nil, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.description, func(t *testing.T) { + toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("failed to parse input: %v", err) + } + if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" { + t.Fatalf("incorrect sources parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" { + t.Fatalf("incorrect authServices parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" { + t.Fatalf("incorrect tools parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Toolsets, 
toolsFile.Toolsets); diff != "" { + t.Fatalf("incorrect toolsets parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" { + t.Fatalf("incorrect prompts parse: diff %v", diff) + } + }) + } + +} + +func TestEnvVarReplacement(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + t.Setenv("TestHeader", "ACTUAL_HEADER") + t.Setenv("API_KEY", "ACTUAL_API_KEY") + t.Setenv("clientId", "ACTUAL_CLIENT_ID") + t.Setenv("clientId2", "ACTUAL_CLIENT_ID_2") + t.Setenv("toolset_name", "ACTUAL_TOOLSET_NAME") + t.Setenv("cat_string", "cat") + t.Setenv("food_string", "food") + t.Setenv("TestHeader", "ACTUAL_HEADER") + t.Setenv("prompt_name", "ACTUAL_PROMPT_NAME") + t.Setenv("prompt_content", "ACTUAL_CONTENT") + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + description string + in string + wantToolsFile ToolsFile + }{ + { + description: "file with env var example", + in: ` + sources: + my-http-instance: + kind: http + baseUrl: http://test_server/ + timeout: 10s + headers: + Authorization: ${TestHeader} + queryParams: + api-key: ${API_KEY} + authServices: + my-google-service: + kind: google + clientId: ${clientId} + other-google-service: + kind: google + clientId: ${clientId2} + + tools: + example_tool: + kind: http + source: my-instance + method: GET + path: "search?name=alice&pet=${cat_string}" + description: some description + authRequired: + - my-google-auth-service + - other-auth-service + queryParams: + - name: country + type: string + description: some description + authServices: + - name: my-google-auth-service + field: user_id + - name: other-auth-service + field: user_id + requestBody: | + { + "age": {{.age}}, + "city": "{{.city}}", + "food": "${food_string}", + "other": "$OTHER" + } + bodyParams: + - name: age + type: integer + description: age num + - name: city + type: string + description: city string + headers: + Authorization: API_KEY + Content-Type: application/json + 
headerParams: + - name: Language + type: string + description: language string + + toolsets: + ${toolset_name}: + - example_tool + + + prompts: + ${prompt_name}: + description: A test prompt for {{.name}}. + messages: + - role: user + content: ${prompt_content} + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-http-instance": httpsrc.Config{ + Name: "my-http-instance", + Type: httpsrc.SourceType, + BaseURL: "http://test_server/", + Timeout: "10s", + DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"}, + QueryParams: map[string]string{"api-key": "ACTUAL_API_KEY"}, + }, + }, + AuthServices: server.AuthServiceConfigs{ + "my-google-service": google.Config{ + Name: "my-google-service", + Type: google.AuthServiceType, + ClientID: "ACTUAL_CLIENT_ID", + }, + "other-google-service": google.Config{ + Name: "other-google-service", + Type: google.AuthServiceType, + ClientID: "ACTUAL_CLIENT_ID_2", + }, + }, + Tools: server.ToolConfigs{ + "example_tool": http.Config{ + Name: "example_tool", + Type: "http", + Source: "my-instance", + Method: "GET", + Path: "search?name=alice&pet=cat", + Description: "some description", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + QueryParams: []parameters.Parameter{ + parameters.NewStringParameterWithAuth("country", "some description", + []parameters.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, + {Name: "other-auth-service", Field: "user_id"}}), + }, + RequestBody: `{ + "age": {{.age}}, + "city": "{{.city}}", + "food": "food", + "other": "$OTHER" +} +`, + BodyParams: []parameters.Parameter{parameters.NewIntParameter("age", "age num"), parameters.NewStringParameter("city", "city string")}, + Headers: map[string]string{"Authorization": "API_KEY", "Content-Type": "application/json"}, + HeaderParams: []parameters.Parameter{parameters.NewStringParameter("Language", "language string")}, + }, + }, + Toolsets: server.ToolsetConfigs{ + "ACTUAL_TOOLSET_NAME": 
tools.ToolsetConfig{ + Name: "ACTUAL_TOOLSET_NAME", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: server.PromptConfigs{ + "ACTUAL_PROMPT_NAME": &custom.Config{ + Name: "ACTUAL_PROMPT_NAME", + Description: "A test prompt for {{.name}}.", + Messages: []prompts.Message{ + { + Role: "user", + Content: "ACTUAL_CONTENT", + }, + }, + Arguments: nil, + }, + }, + }, + }, + { + description: "file with env var example toolsfile v2", + in: ` + kind: sources + name: my-http-instance + type: http + baseUrl: http://test_server/ + timeout: 10s + headers: + Authorization: ${TestHeader} + queryParams: + api-key: ${API_KEY} +--- + kind: authServices + name: my-google-service + type: google + clientId: ${clientId} +--- + kind: authServices + name: other-google-service + type: google + clientId: ${clientId2} +--- + kind: tools + name: example_tool + type: http + source: my-instance + method: GET + path: "search?name=alice&pet=${cat_string}" + description: some description + authRequired: + - my-google-auth-service + - other-auth-service + queryParams: + - name: country + type: string + description: some description + authServices: + - name: my-google-auth-service + field: user_id + - name: other-auth-service + field: user_id + requestBody: | + { + "age": {{.age}}, + "city": "{{.city}}", + "food": "${food_string}", + "other": "$OTHER" + } + bodyParams: + - name: age + type: integer + description: age num + - name: city + type: string + description: city string + headers: + Authorization: API_KEY + Content-Type: application/json + headerParams: + - name: Language + type: string + description: language string +--- + kind: toolsets + name: ${toolset_name} + tools: + - example_tool +--- + kind: prompts + name: ${prompt_name} + description: A test prompt for {{.name}}. 
+ messages: + - role: user + content: ${prompt_content} + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-http-instance": httpsrc.Config{ + Name: "my-http-instance", + Type: httpsrc.SourceType, + BaseURL: "http://test_server/", + Timeout: "10s", + DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"}, + QueryParams: map[string]string{"api-key": "ACTUAL_API_KEY"}, + }, + }, + AuthServices: server.AuthServiceConfigs{ + "my-google-service": google.Config{ + Name: "my-google-service", + Type: google.AuthServiceType, + ClientID: "ACTUAL_CLIENT_ID", + }, + "other-google-service": google.Config{ + Name: "other-google-service", + Type: google.AuthServiceType, + ClientID: "ACTUAL_CLIENT_ID_2", + }, + }, + Tools: server.ToolConfigs{ + "example_tool": http.Config{ + Name: "example_tool", + Type: "http", + Source: "my-instance", + Method: "GET", + Path: "search?name=alice&pet=cat", + Description: "some description", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + QueryParams: []parameters.Parameter{ + parameters.NewStringParameterWithAuth("country", "some description", + []parameters.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, + {Name: "other-auth-service", Field: "user_id"}}), + }, + RequestBody: `{ + "age": {{.age}}, + "city": "{{.city}}", + "food": "food", + "other": "$OTHER" +} +`, + BodyParams: []parameters.Parameter{parameters.NewIntParameter("age", "age num"), parameters.NewStringParameter("city", "city string")}, + Headers: map[string]string{"Authorization": "API_KEY", "Content-Type": "application/json"}, + HeaderParams: []parameters.Parameter{parameters.NewStringParameter("Language", "language string")}, + }, + }, + Toolsets: server.ToolsetConfigs{ + "ACTUAL_TOOLSET_NAME": tools.ToolsetConfig{ + Name: "ACTUAL_TOOLSET_NAME", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: server.PromptConfigs{ + "ACTUAL_PROMPT_NAME": &custom.Config{ + Name: "ACTUAL_PROMPT_NAME", + 
Description: "A test prompt for {{.name}}.", + Messages: []prompts.Message{ + { + Role: "user", + Content: "ACTUAL_CONTENT", + }, + }, + Arguments: nil, + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.description, func(t *testing.T) { + toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("failed to parse input: %v", err) + } + if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" { + t.Fatalf("incorrect sources parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" { + t.Fatalf("incorrect authServices parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" { + t.Fatalf("incorrect tools parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" { + t.Fatalf("incorrect toolsets parse: diff %v", diff) + } + if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" { + t.Fatalf("incorrect prompts parse: diff %v", diff) + } + }) + } +} + +func TestPrebuiltTools(t *testing.T) { + // Get prebuilt configs + alloydb_omni_config, _ := prebuiltconfigs.Get("alloydb-omni") + alloydb_admin_config, _ := prebuiltconfigs.Get("alloydb-postgres-admin") + alloydb_config, _ := prebuiltconfigs.Get("alloydb-postgres") + bigquery_config, _ := prebuiltconfigs.Get("bigquery") + clickhouse_config, _ := prebuiltconfigs.Get("clickhouse") + cloudsqlpg_config, _ := prebuiltconfigs.Get("cloud-sql-postgres") + cloudsqlpg_admin_config, _ := prebuiltconfigs.Get("cloud-sql-postgres-admin") + cloudsqlmysql_config, _ := prebuiltconfigs.Get("cloud-sql-mysql") + cloudsqlmysql_admin_config, _ := prebuiltconfigs.Get("cloud-sql-mysql-admin") + cloudsqlmssql_config, _ := prebuiltconfigs.Get("cloud-sql-mssql") + cloudsqlmssql_admin_config, _ := prebuiltconfigs.Get("cloud-sql-mssql-admin") + dataplex_config, _ := prebuiltconfigs.Get("dataplex") + 
firestoreconfig, _ := prebuiltconfigs.Get("firestore") + mysql_config, _ := prebuiltconfigs.Get("mysql") + mssql_config, _ := prebuiltconfigs.Get("mssql") + looker_config, _ := prebuiltconfigs.Get("looker") + lookerca_config, _ := prebuiltconfigs.Get("looker-conversational-analytics") + postgresconfig, _ := prebuiltconfigs.Get("postgres") + spanner_config, _ := prebuiltconfigs.Get("spanner") + spannerpg_config, _ := prebuiltconfigs.Get("spanner-postgres") + mindsdb_config, _ := prebuiltconfigs.Get("mindsdb") + sqlite_config, _ := prebuiltconfigs.Get("sqlite") + neo4jconfig, _ := prebuiltconfigs.Get("neo4j") + alloydbobsvconfig, _ := prebuiltconfigs.Get("alloydb-postgres-observability") + cloudsqlpgobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-postgres-observability") + cloudsqlmysqlobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-mysql-observability") + cloudsqlmssqlobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-mssql-observability") + serverless_spark_config, _ := prebuiltconfigs.Get("serverless-spark") + cloudhealthcare_config, _ := prebuiltconfigs.Get("cloud-healthcare") + snowflake_config, _ := prebuiltconfigs.Get("snowflake") + + // Set environment variables + t.Setenv("API_KEY", "your_api_key") + + t.Setenv("BIGQUERY_PROJECT", "your_gcp_project_id") + t.Setenv("DATAPLEX_PROJECT", "your_gcp_project_id") + t.Setenv("FIRESTORE_PROJECT", "your_gcp_project_id") + t.Setenv("FIRESTORE_DATABASE", "your_firestore_db_name") + + t.Setenv("SPANNER_PROJECT", "your_gcp_project_id") + t.Setenv("SPANNER_INSTANCE", "your_spanner_instance") + t.Setenv("SPANNER_DATABASE", "your_spanner_db") + + t.Setenv("ALLOYDB_POSTGRES_PROJECT", "your_gcp_project_id") + t.Setenv("ALLOYDB_POSTGRES_REGION", "your_gcp_region") + t.Setenv("ALLOYDB_POSTGRES_CLUSTER", "your_alloydb_cluster") + t.Setenv("ALLOYDB_POSTGRES_INSTANCE", "your_alloydb_instance") + t.Setenv("ALLOYDB_POSTGRES_DATABASE", "your_alloydb_db") + t.Setenv("ALLOYDB_POSTGRES_USER", "your_alloydb_user") + 
t.Setenv("ALLOYDB_POSTGRES_PASSWORD", "your_alloydb_password") + + t.Setenv("ALLOYDB_OMNI_HOST", "localhost") + t.Setenv("ALLOYDB_OMNI_PORT", "5432") + t.Setenv("ALLOYDB_OMNI_DATABASE", "your_alloydb_db") + t.Setenv("ALLOYDB_OMNI_USER", "your_alloydb_user") + t.Setenv("ALLOYDB_OMNI_PASSWORD", "your_alloydb_password") + + t.Setenv("CLICKHOUSE_PROTOCOL", "your_clickhouse_protocol") + t.Setenv("CLICKHOUSE_DATABASE", "your_clickhouse_database") + t.Setenv("CLICKHOUSE_PASSWORD", "your_clickhouse_password") + t.Setenv("CLICKHOUSE_USER", "your_clickhouse_user") + t.Setenv("CLICKHOUSE_HOST", "your_clickhosue_host") + t.Setenv("CLICKHOUSE_PORT", "8123") + + t.Setenv("CLOUD_SQL_POSTGRES_PROJECT", "your_pg_project") + t.Setenv("CLOUD_SQL_POSTGRES_INSTANCE", "your_pg_instance") + t.Setenv("CLOUD_SQL_POSTGRES_DATABASE", "your_pg_db") + t.Setenv("CLOUD_SQL_POSTGRES_REGION", "your_pg_region") + t.Setenv("CLOUD_SQL_POSTGRES_USER", "your_pg_user") + t.Setenv("CLOUD_SQL_POSTGRES_PASS", "your_pg_pass") + + t.Setenv("CLOUD_SQL_MYSQL_PROJECT", "your_gcp_project_id") + t.Setenv("CLOUD_SQL_MYSQL_REGION", "your_gcp_region") + t.Setenv("CLOUD_SQL_MYSQL_INSTANCE", "your_instance") + t.Setenv("CLOUD_SQL_MYSQL_DATABASE", "your_cloudsql_mysql_db") + t.Setenv("CLOUD_SQL_MYSQL_USER", "your_cloudsql_mysql_user") + t.Setenv("CLOUD_SQL_MYSQL_PASSWORD", "your_cloudsql_mysql_password") + + t.Setenv("CLOUD_SQL_MSSQL_PROJECT", "your_gcp_project_id") + t.Setenv("CLOUD_SQL_MSSQL_REGION", "your_gcp_region") + t.Setenv("CLOUD_SQL_MSSQL_INSTANCE", "your_cloudsql_mssql_instance") + t.Setenv("CLOUD_SQL_MSSQL_DATABASE", "your_cloudsql_mssql_db") + t.Setenv("CLOUD_SQL_MSSQL_IP_ADDRESS", "127.0.0.1") + t.Setenv("CLOUD_SQL_MSSQL_USER", "your_cloudsql_mssql_user") + t.Setenv("CLOUD_SQL_MSSQL_PASSWORD", "your_cloudsql_mssql_password") + t.Setenv("CLOUD_SQL_POSTGRES_PASSWORD", "your_cloudsql_pg_password") + + t.Setenv("SERVERLESS_SPARK_PROJECT", "your_gcp_project_id") + t.Setenv("SERVERLESS_SPARK_LOCATION", 
"your_gcp_location") + + t.Setenv("POSTGRES_HOST", "localhost") + t.Setenv("POSTGRES_PORT", "5432") + t.Setenv("POSTGRES_DATABASE", "your_postgres_db") + t.Setenv("POSTGRES_USER", "your_postgres_user") + t.Setenv("POSTGRES_PASSWORD", "your_postgres_password") + + t.Setenv("MYSQL_HOST", "localhost") + t.Setenv("MYSQL_PORT", "3306") + t.Setenv("MYSQL_DATABASE", "your_mysql_db") + t.Setenv("MYSQL_USER", "your_mysql_user") + t.Setenv("MYSQL_PASSWORD", "your_mysql_password") + + t.Setenv("MSSQL_HOST", "localhost") + t.Setenv("MSSQL_PORT", "1433") + t.Setenv("MSSQL_DATABASE", "your_mssql_db") + t.Setenv("MSSQL_USER", "your_mssql_user") + t.Setenv("MSSQL_PASSWORD", "your_mssql_password") + + t.Setenv("MINDSDB_HOST", "localhost") + t.Setenv("MINDSDB_PORT", "47334") + t.Setenv("MINDSDB_DATABASE", "your_mindsdb_db") + t.Setenv("MINDSDB_USER", "your_mindsdb_user") + t.Setenv("MINDSDB_PASS", "your_mindsdb_password") + + t.Setenv("LOOKER_BASE_URL", "https://your_company.looker.com") + t.Setenv("LOOKER_CLIENT_ID", "your_looker_client_id") + t.Setenv("LOOKER_CLIENT_SECRET", "your_looker_client_secret") + t.Setenv("LOOKER_VERIFY_SSL", "true") + + t.Setenv("LOOKER_PROJECT", "your_project_id") + t.Setenv("LOOKER_LOCATION", "us") + + t.Setenv("SQLITE_DATABASE", "test.db") + + t.Setenv("NEO4J_URI", "bolt://localhost:7687") + t.Setenv("NEO4J_DATABASE", "neo4j") + t.Setenv("NEO4J_USERNAME", "your_neo4j_user") + t.Setenv("NEO4J_PASSWORD", "your_neo4j_password") + + t.Setenv("CLOUD_HEALTHCARE_PROJECT", "your_gcp_project_id") + t.Setenv("CLOUD_HEALTHCARE_REGION", "your_gcp_region") + t.Setenv("CLOUD_HEALTHCARE_DATASET", "your_healthcare_dataset") + + t.Setenv("SNOWFLAKE_ACCOUNT", "your_account") + t.Setenv("SNOWFLAKE_USER", "your_username") + t.Setenv("SNOWFLAKE_PASSWORD", "your_pass") + t.Setenv("SNOWFLAKE_DATABASE", "your_db") + t.Setenv("SNOWFLAKE_SCHEMA", "your_schema") + t.Setenv("SNOWFLAKE_WAREHOUSE", "your_wh") + t.Setenv("SNOWFLAKE_ROLE", "your_role") + + ctx, err := 
testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + name string + in []byte + wantToolset server.ToolsetConfigs + }{ + { + name: "alloydb omni prebuilt tools", + in: alloydb_omni_config, + wantToolset: server.ToolsetConfigs{ + "alloydb_omni_database_tools": tools.ToolsetConfig{ + Name: "alloydb_omni_database_tools", + ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_columnar_configurations", "list_columnar_recommended_columns", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, + }, + }, + }, + { + name: "alloydb postgres admin prebuilt tools", + in: alloydb_admin_config, + wantToolset: server.ToolsetConfigs{ + "alloydb_postgres_admin_tools": tools.ToolsetConfig{ + Name: "alloydb_postgres_admin_tools", + ToolNames: []string{"create_cluster", "wait_for_operation", "create_instance", "list_clusters", "list_instances", "list_users", "create_user", "get_cluster", "get_instance", "get_user"}, + }, + }, + }, + { + name: "cloudsql pg admin prebuilt tools", + in: cloudsqlpg_admin_config, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_postgres_admin_tools": tools.ToolsetConfig{ + Name: "cloud_sql_postgres_admin_tools", + ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup", "restore_backup"}, + }, + }, + 
}, + { + name: "cloudsql mysql admin prebuilt tools", + in: cloudsqlmysql_admin_config, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_mysql_admin_tools": tools.ToolsetConfig{ + Name: "cloud_sql_mysql_admin_tools", + ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"}, + }, + }, + }, + { + name: "cloudsql mssql admin prebuilt tools", + in: cloudsqlmssql_admin_config, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_mssql_admin_tools": tools.ToolsetConfig{ + Name: "cloud_sql_mssql_admin_tools", + ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"}, + }, + }, + }, + { + name: "alloydb prebuilt tools", + in: alloydb_config, + wantToolset: server.ToolsetConfigs{ + "alloydb_postgres_database_tools": tools.ToolsetConfig{ + Name: "alloydb_postgres_database_tools", + ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, + }, + }, + }, + { + name: "bigquery prebuilt tools", + in: bigquery_config, + wantToolset: server.ToolsetConfigs{ + "bigquery_database_tools": tools.ToolsetConfig{ + Name: "bigquery_database_tools", + ToolNames: []string{"analyze_contribution", "ask_data_insights", 
"execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids", "search_catalog"}, + }, + }, + }, + { + name: "clickhouse prebuilt tools", + in: clickhouse_config, + wantToolset: server.ToolsetConfigs{ + "clickhouse_database_tools": tools.ToolsetConfig{ + Name: "clickhouse_database_tools", + ToolNames: []string{"execute_sql", "list_databases", "list_tables"}, + }, + }, + }, + { + name: "cloudsqlpg prebuilt tools", + in: cloudsqlpg_config, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_postgres_database_tools": tools.ToolsetConfig{ + Name: "cloud_sql_postgres_database_tools", + ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, + }, + }, + }, + { + name: "cloudsqlmysql prebuilt tools", + in: cloudsqlmysql_config, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_mysql_database_tools": tools.ToolsetConfig{ + Name: "cloud_sql_mysql_database_tools", + ToolNames: []string{"execute_sql", "list_tables", "get_query_plan", "list_active_queries", "list_tables_missing_unique_indexes", "list_table_fragmentation"}, + }, + }, + }, + { + name: "cloudsqlmssql prebuilt tools", + in: cloudsqlmssql_config, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_mssql_database_tools": tools.ToolsetConfig{ + Name: "cloud_sql_mssql_database_tools", + ToolNames: []string{"execute_sql", "list_tables"}, + }, + }, + }, + { + name: "dataplex prebuilt tools", + in: 
dataplex_config, + wantToolset: server.ToolsetConfigs{ + "dataplex_tools": tools.ToolsetConfig{ + Name: "dataplex_tools", + ToolNames: []string{"search_entries", "lookup_entry", "search_aspect_types"}, + }, + }, + }, + { + name: "serverless spark prebuilt tools", + in: serverless_spark_config, + wantToolset: server.ToolsetConfigs{ + "serverless_spark_tools": tools.ToolsetConfig{ + Name: "serverless_spark_tools", + ToolNames: []string{"list_batches", "get_batch", "cancel_batch", "create_pyspark_batch", "create_spark_batch"}, + }, + }, + }, + { + name: "firestore prebuilt tools", + in: firestoreconfig, + wantToolset: server.ToolsetConfigs{ + "firestore_database_tools": tools.ToolsetConfig{ + Name: "firestore_database_tools", + ToolNames: []string{"get_documents", "add_documents", "update_document", "list_collections", "delete_documents", "query_collection", "get_rules", "validate_rules"}, + }, + }, + }, + { + name: "mysql prebuilt tools", + in: mysql_config, + wantToolset: server.ToolsetConfigs{ + "mysql_database_tools": tools.ToolsetConfig{ + Name: "mysql_database_tools", + ToolNames: []string{"execute_sql", "list_tables", "get_query_plan", "list_active_queries", "list_tables_missing_unique_indexes", "list_table_fragmentation"}, + }, + }, + }, + { + name: "mssql prebuilt tools", + in: mssql_config, + wantToolset: server.ToolsetConfigs{ + "mssql_database_tools": tools.ToolsetConfig{ + Name: "mssql_database_tools", + ToolNames: []string{"execute_sql", "list_tables"}, + }, + }, + }, + { + name: "looker prebuilt tools", + in: looker_config, + wantToolset: server.ToolsetConfigs{ + "looker_tools": tools.ToolsetConfig{ + Name: "looker_tools", + ToolNames: []string{"get_models", "get_explores", "get_dimensions", "get_measures", "get_filters", "get_parameters", "query", "query_sql", "query_url", "get_looks", "run_look", "make_look", "get_dashboards", "run_dashboard", "make_dashboard", "add_dashboard_element", "add_dashboard_filter", "generate_embed_url", "health_pulse", 
"health_analyze", "health_vacuum", "dev_mode", "get_projects", "get_project_files", "get_project_file", "create_project_file", "update_project_file", "delete_project_file", "validate_project", "get_connections", "get_connection_schemas", "get_connection_databases", "get_connection_tables", "get_connection_table_columns"}, + }, + }, + }, + { + name: "looker-conversational-analytics prebuilt tools", + in: lookerca_config, + wantToolset: server.ToolsetConfigs{ + "looker_conversational_analytics_tools": tools.ToolsetConfig{ + Name: "looker_conversational_analytics_tools", + ToolNames: []string{"ask_data_insights", "get_models", "get_explores"}, + }, + }, + }, + { + name: "postgres prebuilt tools", + in: postgresconfig, + wantToolset: server.ToolsetConfigs{ + "postgres_database_tools": tools.ToolsetConfig{ + Name: "postgres_database_tools", + ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"}, + }, + }, + }, + { + name: "spanner prebuilt tools", + in: spanner_config, + wantToolset: server.ToolsetConfigs{ + "spanner-database-tools": tools.ToolsetConfig{ + Name: "spanner-database-tools", + ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables", "list_graphs"}, + }, + }, + }, + { + name: "spanner pg prebuilt tools", + in: spannerpg_config, + wantToolset: server.ToolsetConfigs{ + "spanner_postgres_database_tools": tools.ToolsetConfig{ + Name: 
"spanner_postgres_database_tools", + ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables"}, + }, + }, + }, + { + name: "mindsdb prebuilt tools", + in: mindsdb_config, + wantToolset: server.ToolsetConfigs{ + "mindsdb-tools": tools.ToolsetConfig{ + Name: "mindsdb-tools", + ToolNames: []string{"mindsdb-execute-sql", "mindsdb-sql"}, + }, + }, + }, + { + name: "sqlite prebuilt tools", + in: sqlite_config, + wantToolset: server.ToolsetConfigs{ + "sqlite_database_tools": tools.ToolsetConfig{ + Name: "sqlite_database_tools", + ToolNames: []string{"execute_sql", "list_tables"}, + }, + }, + }, + { + name: "neo4j prebuilt tools", + in: neo4jconfig, + wantToolset: server.ToolsetConfigs{ + "neo4j_database_tools": tools.ToolsetConfig{ + Name: "neo4j_database_tools", + ToolNames: []string{"execute_cypher", "get_schema"}, + }, + }, + }, + { + name: "alloydb postgres observability prebuilt tools", + in: alloydbobsvconfig, + wantToolset: server.ToolsetConfigs{ + "alloydb_postgres_cloud_monitoring_tools": tools.ToolsetConfig{ + Name: "alloydb_postgres_cloud_monitoring_tools", + ToolNames: []string{"get_system_metrics", "get_query_metrics"}, + }, + }, + }, + { + name: "cloudsql postgres observability prebuilt tools", + in: cloudsqlpgobsvconfig, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_postgres_cloud_monitoring_tools": tools.ToolsetConfig{ + Name: "cloud_sql_postgres_cloud_monitoring_tools", + ToolNames: []string{"get_system_metrics", "get_query_metrics"}, + }, + }, + }, + { + name: "cloudsql mysql observability prebuilt tools", + in: cloudsqlmysqlobsvconfig, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_mysql_cloud_monitoring_tools": tools.ToolsetConfig{ + Name: "cloud_sql_mysql_cloud_monitoring_tools", + ToolNames: []string{"get_system_metrics", "get_query_metrics"}, + }, + }, + }, + { + name: "cloudsql mssql observability prebuilt tools", + in: cloudsqlmssqlobsvconfig, + wantToolset: server.ToolsetConfigs{ + "cloud_sql_mssql_cloud_monitoring_tools": 
tools.ToolsetConfig{ + Name: "cloud_sql_mssql_cloud_monitoring_tools", + ToolNames: []string{"get_system_metrics"}, + }, + }, + }, + { + name: "cloud healthcare prebuilt tools", + in: cloudhealthcare_config, + wantToolset: server.ToolsetConfigs{ + "cloud_healthcare_dataset_tools": tools.ToolsetConfig{ + Name: "cloud_healthcare_dataset_tools", + ToolNames: []string{"get_dataset", "list_dicom_stores", "list_fhir_stores"}, + }, + "cloud_healthcare_fhir_tools": tools.ToolsetConfig{ + Name: "cloud_healthcare_fhir_tools", + ToolNames: []string{"get_fhir_store", "get_fhir_store_metrics", "get_fhir_resource", "fhir_patient_search", "fhir_patient_everything", "fhir_fetch_page"}, + }, + "cloud_healthcare_dicom_tools": tools.ToolsetConfig{ + Name: "cloud_healthcare_dicom_tools", + ToolNames: []string{"get_dicom_store", "get_dicom_store_metrics", "search_dicom_studies", "search_dicom_series", "search_dicom_instances", "retrieve_rendered_dicom_instance"}, + }, + }, + }, + { + name: "snowflake prebuilt tools", + in: snowflake_config, + wantToolset: server.ToolsetConfigs{ + "snowflake_tools": tools.ToolsetConfig{ + Name: "snowflake_tools", + ToolNames: []string{"execute_sql", "list_tables"}, + }, + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + toolsFile, err := parseToolsFile(ctx, tc.in) + if err != nil { + t.Fatalf("failed to parse input: %v", err) + } + if diff := cmp.Diff(tc.wantToolset, toolsFile.Toolsets); diff != "" { + t.Fatalf("incorrect tools parse: diff %v", diff) + } + // Prebuilt configs do not have prompts, so assert empty maps. 
+ if len(toolsFile.Prompts) != 0 { + t.Fatalf("expected empty prompts map for prebuilt config, got: %v", toolsFile.Prompts) + } + }) + } +} + +func TestMergeToolsFiles(t *testing.T) { + file1 := ToolsFile{ + Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, + Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}}, + Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}}, + EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, + } + file2 := ToolsFile{ + AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}}, + Tools: server.ToolConfigs{"tool2": http.Config{Name: "tool2"}}, + Toolsets: server.ToolsetConfigs{"set2": tools.ToolsetConfig{Name: "set2"}}, + } + fileWithConflicts := ToolsFile{ + Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, + Tools: server.ToolConfigs{"tool2": http.Config{Name: "tool2"}}, + } + + testCases := []struct { + name string + files []ToolsFile + want ToolsFile + wantErr bool + }{ + { + name: "merge two distinct files", + files: []ToolsFile{file1, file2}, + want: ToolsFile{ + Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, + AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}}, + Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}}, + Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}}, + Prompts: server.PromptConfigs{}, + EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, + }, + wantErr: false, + }, + { + name: "merge with conflicts", + files: []ToolsFile{file1, file2, fileWithConflicts}, + wantErr: true, + }, + { + name: "merge single file", + files: []ToolsFile{file1}, + want: ToolsFile{ + Sources: file1.Sources, + AuthServices: make(server.AuthServiceConfigs), + EmbeddingModels: 
server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, + Tools: file1.Tools, + Toolsets: file1.Toolsets, + Prompts: server.PromptConfigs{}, + }, + }, + { + name: "merge empty list", + files: []ToolsFile{}, + want: ToolsFile{ + Sources: make(server.SourceConfigs), + AuthServices: make(server.AuthServiceConfigs), + EmbeddingModels: make(server.EmbeddingModelConfigs), + Tools: make(server.ToolConfigs), + Toolsets: make(server.ToolsetConfigs), + Prompts: server.PromptConfigs{}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := mergeToolsFiles(tc.files...) + if (err != nil) != tc.wantErr { + t.Fatalf("mergeToolsFiles() error = %v, wantErr %v", err, tc.wantErr) + } + if !tc.wantErr { + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("mergeToolsFiles() mismatch (-want +got):\n%s", diff) + } + } else { + if err == nil { + t.Fatal("expected an error for conflicting files but got none") + } + if !strings.Contains(err.Error(), "resource conflicts detected") { + t.Errorf("expected conflict error, but got: %v", err) + } + } + }) + } +} + +func TestParameterReferenceValidation(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + // Base template + baseYaml := ` +sources: + dummy-source: + kind: http + baseUrl: http://example.com +tools: + test-tool: + kind: postgres-sql + source: dummy-source + description: test tool + statement: SELECT 1; + parameters: +%s` + + tcs := []struct { + desc string + params string + wantErr bool + errSubstr string + }{ + { + desc: "valid backward reference", + params: ` + - name: source_param + type: string + description: source + - name: copy_param + type: string + description: copy + valueFromParam: source_param`, + wantErr: false, + }, + { + desc: "valid forward reference (out of order)", + params: ` + - name: copy_param + type: string + description: copy + valueFromParam: source_param + - 
name: source_param + type: string + description: source`, + wantErr: false, + }, + { + desc: "invalid missing reference", + params: ` + - name: copy_param + type: string + description: copy + valueFromParam: non_existent_param`, + wantErr: true, + errSubstr: "references '\"non_existent_param\"' in the 'valueFromParam' field", + }, + { + desc: "invalid self reference", + params: ` + - name: myself + type: string + description: self + valueFromParam: myself`, + wantErr: true, + errSubstr: "parameter \"myself\" cannot copy value from itself", + }, + { + desc: "multiple valid references", + params: ` + - name: a + type: string + description: a + - name: b + type: string + description: b + valueFromParam: a + - name: c + type: string + description: c + valueFromParam: a`, + wantErr: false, + }, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Indent parameters to match YAML structure + yamlContent := fmt.Sprintf(baseYaml, tc.params) + + _, err := parseToolsFile(ctx, []byte(yamlContent)) + + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tc.errSubstr) { + t.Errorf("error %q does not contain expected substring %q", err.Error(), tc.errSubstr) + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } + }) + } +} diff --git a/tests/server.go b/tests/server.go index eac693f3a2..d73d780039 100644 --- a/tests/server.go +++ b/tests/server.go @@ -21,8 +21,10 @@ import ( "os" yaml "github.com/goccy/go-yaml" + "github.com/spf13/cobra" "github.com/googleapis/genai-toolbox/cmd" + "github.com/googleapis/genai-toolbox/internal/cli" ) // tmpFileWithCleanup creates a temporary file with the content and returns the path and @@ -50,7 +52,7 @@ func tmpFileWithCleanup(content []byte) (string, func(), error) { type CmdExec struct { Out io.ReadCloser - cmd *cmd.Command + cmd *cobra.Command cancel context.CancelFunc closers []io.Closer done chan bool // closed once the cmd is 
completed @@ -77,7 +79,8 @@ func StartCmd(ctx context.Context, toolsFile map[string]any, args ...string) (*C return nil, nil, fmt.Errorf("unable to open stdout pipe: %w", err) } - c := cmd.NewCommand(cmd.WithStreams(pw, pw)) + opts := cli.NewToolboxOptions(cli.WithIOStreams(pw, pw)) + c := cmd.NewCommand(opts) c.SetArgs(args) t := &CmdExec{