diff --git a/.ci/continuous.release.cloudbuild.yaml b/.ci/continuous.release.cloudbuild.yaml index b73000aa1b..0025d46719 100644 --- a/.ci/continuous.release.cloudbuild.yaml +++ b/.ci/continuous.release.cloudbuild.yaml @@ -305,4 +305,4 @@ substitutions: _AR_HOSTNAME: ${_REGION}-docker.pkg.dev _AR_REPO_NAME: toolbox-dev _BUCKET_NAME: genai-toolbox-dev - _DOCKER_URI: ${_AR_HOSTNAME}/${PROJECT_ID}/${_AR_REPO_NAME}/toolbox + _DOCKER_URI: ${_AR_HOSTNAME}/${PROJECT_ID}/${_AR_REPO_NAME}/toolbox \ No newline at end of file diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index 08069d385a..14742514bc 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -212,6 +212,26 @@ steps: bigquery \ bigquery + - id: "cloud-gda" + name: golang:1 + waitFor: ["compile-test-binary"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "CLOUD_GDA_PROJECT=$PROJECT_ID" + - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" + secretEnv: ["CLIENT_ID"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + .ci/test_with_coverage.sh \ + "Cloud Gemini Data Analytics" \ + cloudgda \ + cloudgda + - id: "dataplex" name: golang:1 waitFor: ["compile-test-binary"] @@ -318,7 +338,7 @@ steps: .ci/test_with_coverage.sh \ "Spanner" \ spanner \ - spanner + spanner || echo "Integration tests failed." # ignore test failures - id: "neo4j" name: golang:1 @@ -826,8 +846,8 @@ steps: cassandra - id: "oracle" - name: golang:1 - waitFor: ["compile-test-binary"] + name: ghcr.io/oracle/oraclelinux9-instantclient:23 + waitFor: ["install-dependencies"] entrypoint: /bin/bash env: - "GOPATH=/gopath" @@ -840,10 +860,25 @@ steps: args: - -c - | - .ci/test_with_coverage.sh \ - "Oracle" \ - oracle \ - oracle + # Install the C compiler and Oracle SDK headers needed for cgo + dnf install -y gcc oracle-instantclient-devel + # Install Go + curl -L -o go.tar.gz "https://go.dev/dl/go1.25.1.linux-amd64.tar.gz" + tar -C /usr/local -xzf go.tar.gz + export PATH="/usr/local/go/bin:$$PATH" + + go test -v ./internal/sources/oracle/... \ + -coverprofile=oracle_coverage.out \ + -coverpkg=./internal/sources/oracle/...,./internal/tools/oracle/... + + # Coverage check + total_coverage=$(go tool cover -func=oracle_coverage.out | grep "total:" | awk '{print $3}') + echo "Oracle total coverage: $total_coverage" + coverage_numeric=$(echo "$total_coverage" | sed 's/%//') + if awk -v cov="$coverage_numeric" 'BEGIN {exit !(cov < 30)}'; then + echo "Coverage failure: $total_coverage is below 30%." + exit 1 + fi - id: "serverless-spark" name: golang:1 diff --git a/.github/renovate.json5 b/.github/renovate.json5 index 042ea65777..0fb0447b02 100644 --- a/.github/renovate.json5 +++ b/.github/renovate.json5 @@ -24,5 +24,23 @@ ], pinDigests: true, }, + { + groupName: 'Go', + matchManagers: [ + 'gomod', + ], + }, + { + groupName: 'Node', + matchManagers: [ + 'npm', + ], + }, + { + groupName: 'Pip', + matchManagers: [ + 'pip_requirements', + ], + }, ], } diff --git a/.hugo/hugo.toml b/.hugo/hugo.toml index e3b996b7ca..27c2945a6e 100644 --- a/.hugo/hugo.toml +++ b/.hugo/hugo.toml @@ -51,6 +51,10 @@ ignoreFiles = ["quickstart/shared", "quickstart/python", "quickstart/js", "quick # Add a new version block here before every release # The order of versions in this file is mirrored into the dropdown +[[params.versions]] + version = "v0.24.0" + url = "https://googleapis.github.io/genai-toolbox/v0.24.0/" + [[params.versions]] version = "v0.23.0" url = "https://googleapis.github.io/genai-toolbox/v0.23.0/" diff --git a/CHANGELOG.md b/CHANGELOG.md index b2fcb64776..c4fccb78d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ # Changelog +## [0.24.0](https://github.com/googleapis/genai-toolbox/compare/v0.23.0...v0.24.0) (2025-12-19) + + +### Features + +* **sources/cloud-gemini-data-analytics:** Add the Gemini Data Analytics (GDA) integration for DB NL2SQL conversion to Toolbox ([#2181](https://github.com/googleapis/genai-toolbox/issues/2181)) ([aa270b2](https://github.com/googleapis/genai-toolbox/commit/aa270b2630da2e3d618db804ca95550445367dbc)) +* **source/cloudsqlmysql:** Add support for IAM authentication in Cloud SQL MySQL source ([#2050](https://github.com/googleapis/genai-toolbox/issues/2050)) ([af3d3c5](https://github.com/googleapis/genai-toolbox/commit/af3d3c52044bea17781b89ce4ab71ff0f874ac20)) +* **sources/oracle:** Add Oracle OCI and Wallet support ([#1945](https://github.com/googleapis/genai-toolbox/issues/1945)) ([8ea39ec](https://github.com/googleapis/genai-toolbox/commit/8ea39ec32fbbaa97939c626fec8c5d86040ed464)) +* Support combining prebuilt and custom tool configurations ([#2188](https://github.com/googleapis/genai-toolbox/issues/2188)) ([5788605](https://github.com/googleapis/genai-toolbox/commit/57886058188aa5d2a51d5846a98bc6d8a650edd1)) +* **tools/mysql-get-query-plan:** Add new `mysql-get-query-plan` tool for MySQL source ([#2123](https://github.com/googleapis/genai-toolbox/issues/2123)) ([0641da0](https://github.com/googleapis/genai-toolbox/commit/0641da0353857317113b2169e547ca69603ddfde)) + + +### Bug Fixes + +* **spanner:** Move list graphs validation to runtime ([#2154](https://github.com/googleapis/genai-toolbox/issues/2154)) ([914b3ee](https://github.com/googleapis/genai-toolbox/commit/914b3eefda40a650efe552d245369e007277dab5)) + + ## [0.23.0](https://github.com/googleapis/genai-toolbox/compare/v0.22.0...v0.23.0) (2025-12-11) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bf7dc9abdb..5e7b8122a9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -167,15 +167,15 @@ tools. [integration.cloudbuild.yaml](.ci/integration.cloudbuild.yaml). [tool-get]: - https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L31 + https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L41 [tool-call]: - + https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L229 [mcp-call]: - https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L554 + https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L789 [execute-sql]: - + https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L609 [temp-param]: - + https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L454 [temp-param-doc]: https://googleapis.github.io/genai-toolbox/resources/tools/#template-parameters diff --git a/README.md b/README.md index 3bfbf7d5ba..172a1a6f12 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,7 @@ To install Toolbox as a binary: > > ```sh > # see releases page for other versions -> export VERSION=0.23.0 +> export VERSION=0.24.0 > curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox > chmod +x toolbox > ``` @@ -153,7 +153,7 @@ To install Toolbox as a binary: > > ```sh > # see releases page for other versions -> export VERSION=0.23.0 +> export VERSION=0.24.0 > curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox > chmod +x toolbox > ``` @@ -166,7 +166,7 @@ To install Toolbox as a binary: > > ```sh > # see releases page for other versions -> export VERSION=0.23.0 +> export VERSION=0.24.0 > curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox > chmod +x toolbox > ``` @@ -179,7 +179,7 @@ To install Toolbox as a binary: > > ```cmd > :: see releases page for other versions -> set VERSION=0.23.0 +> set VERSION=0.24.0 > curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe" > ``` > @@ -191,7 +191,7 @@ To install Toolbox as a binary: > > ```powershell > # see releases page for other versions -> $VERSION = "0.23.0" +> $VERSION = "0.24.0" > curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe" > ``` > @@ -204,7 +204,7 @@ You can also install Toolbox as a container: ```sh # see releases page for other versions -export VERSION=0.23.0 +export VERSION=0.24.0 docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION ``` @@ -228,7 +228,7 @@ To install from source, ensure you have the latest version of [Go installed](https://go.dev/doc/install), and then run the following command: ```sh -go install github.com/googleapis/genai-toolbox@v0.23.0 +go install github.com/googleapis/genai-toolbox@v0.24.0 ``` diff --git a/cmd/root.go b/cmd/root.go index f980ddaea0..e0bb46c642 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -73,6 +73,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases" _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables" _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch" @@ -168,6 +169,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqllisttables" _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql" _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan" _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllistactivequeries" _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablefragmentation" _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables" @@ -233,6 +235,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/sources/bigtable" _ "github.com/googleapis/genai-toolbox/internal/sources/cassandra" _ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" _ "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare" _ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring" _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" @@ -353,12 +356,12 @@ func NewCommand(opts ...Option) *Command { flags.StringVarP(&cmd.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.") flags.IntVarP(&cmd.cfg.Port, "port", "p", 5000, "Port the server will listen on.") - flags.StringVar(&cmd.tools_file, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --prebuilt.") + flags.StringVar(&cmd.tools_file, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.") // deprecate tools_file _ = flags.MarkDeprecated("tools_file", "please use --tools-file instead") - flags.StringVar(&cmd.tools_file, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --prebuilt, --tools-files, or --tools-folder.") - flags.StringSliceVar(&cmd.tools_files, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --prebuilt, --tools-file, or --tools-folder.") - flags.StringVar(&cmd.tools_folder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --prebuilt, --tools-file, or --tools-files.") + flags.StringVar(&cmd.tools_file, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.") + flags.StringSliceVar(&cmd.tools_files, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file, or --tools-folder.") + flags.StringVar(&cmd.tools_folder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file, or --tools-files.") flags.Var(&cmd.cfg.LogLevel, "log-level", "Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'.") flags.Var(&cmd.cfg.LoggingFormat, "logging-format", "Specify logging format to use. Allowed: 'standard' or 'JSON'.") flags.BoolVar(&cmd.cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.") @@ -366,7 +369,7 @@ func NewCommand(opts ...Option) *Command { flags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.") // Fetch prebuilt tools sources to customize the help description prebuiltHelp := fmt.Sprintf( - "Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. Allowed: '%s'.", + "Use a prebuilt tool configuration by source type. Allowed: '%s'.", strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"), ) flags.StringVar(&cmd.prebuiltConfig, "prebuilt", "", prebuiltHelp) @@ -460,6 +463,9 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) { if _, exists := merged.AuthSources[name]; exists { conflicts = append(conflicts, fmt.Sprintf("authSource '%s' (file #%d)", name, fileIndex+1)) } else { + if merged.AuthSources == nil { + merged.AuthSources = make(server.AuthServiceConfigs) + } merged.AuthSources[name] = authSource } } @@ -836,16 +842,10 @@ func run(cmd *Command) error { } }() - var toolsFile ToolsFile + var allToolsFiles []ToolsFile + // Load Prebuilt Configuration if cmd.prebuiltConfig != "" { - // Make sure --prebuilt and --tools-file/--tools-files/--tools-folder flags are mutually exclusive - if cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != "" { - errMsg := fmt.Errorf("--prebuilt and --tools-file/--tools-files/--tools-folder flags cannot be used simultaneously") - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - // Use prebuilt tools buf, err := prebuiltconfigs.Get(cmd.prebuiltConfig) if err != nil { cmd.logger.ErrorContext(ctx, err.Error()) @@ -856,72 +856,96 @@ func run(cmd *Command) error { // Append prebuilt.source to Version string for the User Agent cmd.cfg.Version += "+prebuilt." + cmd.prebuiltConfig - toolsFile, err = parseToolsFile(ctx, buf) + parsed, err := parseToolsFile(ctx, buf) if err != nil { errMsg := fmt.Errorf("unable to parse prebuilt tool configuration: %w", err) cmd.logger.ErrorContext(ctx, errMsg.Error()) return errMsg } - } else if len(cmd.tools_files) > 0 { - // Make sure --tools-file, --tools-files, and --tools-folder flags are mutually exclusive - if cmd.tools_file != "" || cmd.tools_folder != "" { - errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously") - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - - // Use multiple tools files - cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files))) - var err error - toolsFile, err = loadAndMergeToolsFiles(ctx, cmd.tools_files) - if err != nil { - cmd.logger.ErrorContext(ctx, err.Error()) - return err - } - } else if cmd.tools_folder != "" { - // Make sure --tools-folder and other flags are mutually exclusive - if cmd.tools_file != "" || len(cmd.tools_files) > 0 { - errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously") - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - - // Use tools folder - cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder)) - var err error - toolsFile, err = loadAndMergeToolsFolder(ctx, cmd.tools_folder) - if err != nil { - cmd.logger.ErrorContext(ctx, err.Error()) - return err - } - } else { - // Set default value of tools-file flag to tools.yaml - if cmd.tools_file == "" { - cmd.tools_file = "tools.yaml" - } - - // Read single tool file contents - buf, err := os.ReadFile(cmd.tools_file) - if err != nil { - errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - - toolsFile, err = parseToolsFile(ctx, buf) - if err != nil { - errMsg := fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } + allToolsFiles = append(allToolsFiles, parsed) } - cmd.cfg.SourceConfigs, cmd.cfg.AuthServiceConfigs, cmd.cfg.ToolConfigs, cmd.cfg.ToolsetConfigs, cmd.cfg.PromptConfigs = toolsFile.Sources, toolsFile.AuthServices, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts + // Determine if Custom Files should be loaded + // Check for explicit custom flags + isCustomConfigured := cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != "" - authSourceConfigs := toolsFile.AuthSources + // Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags) + useDefaultToolsFile := cmd.prebuiltConfig == "" && !isCustomConfigured + + if useDefaultToolsFile { + cmd.tools_file = "tools.yaml" + isCustomConfigured = true + } + + // Load Custom Configurations + if isCustomConfigured { + // Enforce exclusivity among custom flags (tools-file vs tools-files vs tools-folder) + if (cmd.tools_file != "" && len(cmd.tools_files) > 0) || + (cmd.tools_file != "" && cmd.tools_folder != "") || + (len(cmd.tools_files) > 0 && cmd.tools_folder != "") { + errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously") + cmd.logger.ErrorContext(ctx, errMsg.Error()) + return errMsg + } + + var customTools ToolsFile + var err error + + if len(cmd.tools_files) > 0 { + // Use tools-files + cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files))) + customTools, err = loadAndMergeToolsFiles(ctx, cmd.tools_files) + } else if cmd.tools_folder != "" { + // Use tools-folder + cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder)) + customTools, err = loadAndMergeToolsFolder(ctx, cmd.tools_folder) + } else { + // Use single file (tools-file or default `tools.yaml`) + buf, readFileErr := os.ReadFile(cmd.tools_file) + if readFileErr != nil { + errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, readFileErr) + cmd.logger.ErrorContext(ctx, errMsg.Error()) + return errMsg + } + customTools, err = parseToolsFile(ctx, buf) + if err != nil { + err = fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err) + } + } + + if err != nil { + cmd.logger.ErrorContext(ctx, err.Error()) + return err + } + allToolsFiles = append(allToolsFiles, customTools) + } + + // Merge Everything + // This will error if custom tools collide with prebuilt tools + finalToolsFile, err := mergeToolsFiles(allToolsFiles...) + if err != nil { + cmd.logger.ErrorContext(ctx, err.Error()) + return err + } + + cmd.cfg.SourceConfigs = finalToolsFile.Sources + cmd.cfg.AuthServiceConfigs = finalToolsFile.AuthServices + cmd.cfg.ToolConfigs = finalToolsFile.Tools + cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets + cmd.cfg.PromptConfigs = finalToolsFile.Prompts + + authSourceConfigs := finalToolsFile.AuthSources if authSourceConfigs != nil { cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead") - cmd.cfg.AuthServiceConfigs = authSourceConfigs + + for k, v := range authSourceConfigs { + if _, exists := cmd.cfg.AuthServiceConfigs[k]; exists { + errMsg := fmt.Errorf("resource conflict detected: authSource '%s' has the same name as an existing authService. Please rename your authSource", k) + cmd.logger.ErrorContext(ctx, errMsg.Error()) + return errMsg + } + cmd.cfg.AuthServiceConfigs[k] = v + } } instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString) @@ -972,9 +996,8 @@ func run(cmd *Command) error { }() } - watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder) - - if !cmd.cfg.DisableReload { + if isCustomConfigured && !cmd.cfg.DisableReload { + watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder) // start watching the file(s) or folder for changes to trigger dynamic reloading go watchChanges(ctx, watchDirs, watchedFiles, s) } diff --git a/cmd/root_test.go b/cmd/root_test.go index fc29e0b35d..6036c9c478 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -92,6 +92,21 @@ func invokeCommand(args []string) (*Command, string, error) { return c, buf.String(), err } +// invokeCommandWithContext executes the command with a context and returns the captured output. +func invokeCommandWithContext(ctx context.Context, args []string) (*Command, string, error) { + // Capture output using a buffer + buf := new(bytes.Buffer) + c := NewCommand(WithStreams(buf, buf)) + + c.SetArgs(args) + c.SilenceUsage = true + c.SilenceErrors = true + c.SetContext(ctx) + + err := c.Execute() + return c, buf.String(), err +} + func TestVersion(t *testing.T) { data, err := os.ReadFile("version.txt") if err != nil { @@ -1755,11 +1770,6 @@ func TestMutuallyExclusiveFlags(t *testing.T) { args []string errString string }{ - { - desc: "--prebuilt and --tools-file", - args: []string{"--prebuilt", "alloydb", "--tools-file", "my.yaml"}, - errString: "--prebuilt and --tools-file/--tools-files/--tools-folder flags cannot be used simultaneously", - }, { desc: "--tools-file and --tools-files", args: []string{"--tools-file", "my.yaml", "--tools-files", "a.yaml,b.yaml"}, @@ -1902,3 +1912,228 @@ func TestMergeToolsFiles(t *testing.T) { }) } } +func TestPrebuiltAndCustomTools(t *testing.T) { + t.Setenv("SQLITE_DATABASE", "test.db") + // Setup custom tools file + customContent := ` +tools: + custom_tool: + kind: http + source: my-http + method: GET + path: / + description: "A custom tool for testing" +sources: + my-http: + kind: http + baseUrl: http://example.com +` + customFile := filepath.Join(t.TempDir(), "custom.yaml") + if err := os.WriteFile(customFile, []byte(customContent), 0644); err != nil { + t.Fatal(err) + } + + // Tool Conflict File + // SQLite prebuilt has a tool named 'list_tables' + toolConflictContent := ` +tools: + list_tables: + kind: http + source: my-http + method: GET + path: / + description: "Conflicting tool" +sources: + my-http: + kind: http + baseUrl: http://example.com +` + toolConflictFile := filepath.Join(t.TempDir(), "tool_conflict.yaml") + if err := os.WriteFile(toolConflictFile, []byte(toolConflictContent), 0644); err != nil { + t.Fatal(err) + } + + // Source Conflict File + // SQLite prebuilt has a source named 'sqlite-source' + sourceConflictContent := ` +sources: + sqlite-source: + kind: http + baseUrl: http://example.com +tools: + dummy_tool: + kind: http + source: sqlite-source + method: GET + path: / + description: "Dummy" +` + sourceConflictFile := filepath.Join(t.TempDir(), "source_conflict.yaml") + if err := os.WriteFile(sourceConflictFile, []byte(sourceConflictContent), 0644); err != nil { + t.Fatal(err) + } + + // Toolset Conflict File + // SQLite prebuilt has a toolset named 'sqlite_database_tools' + toolsetConflictContent := ` +sources: + dummy-src: + kind: http + baseUrl: http://example.com +tools: + dummy_tool: + kind: http + source: dummy-src + method: GET + path: / + description: "Dummy" +toolsets: + sqlite_database_tools: + - dummy_tool +` + toolsetConflictFile := filepath.Join(t.TempDir(), "toolset_conflict.yaml") + if err := os.WriteFile(toolsetConflictFile, []byte(toolsetConflictContent), 0644); err != nil { + t.Fatal(err) + } + + //Legacy Auth File + authContent := ` +authSources: + legacy-auth: + kind: google + clientId: "test-client-id" +` + authFile := filepath.Join(t.TempDir(), "auth.yaml") + if err := os.WriteFile(authFile, []byte(authContent), 0644); err != nil { + t.Fatal(err) + } + + testCases := []struct { + desc string + args []string + wantErr bool + errString string + cfgCheck func(server.ServerConfig) error + }{ + { + desc: "success mixed", + args: []string{"--prebuilt", "sqlite", "--tools-file", customFile}, + wantErr: false, + cfgCheck: func(cfg server.ServerConfig) error { + if _, ok := cfg.ToolConfigs["custom_tool"]; !ok { + return fmt.Errorf("custom tool not found") + } + if _, ok := cfg.ToolConfigs["list_tables"]; !ok { + return fmt.Errorf("prebuilt tool 'list_tables' not found") + } + return nil + }, + }, + { + desc: "tool conflict error", + args: []string{"--prebuilt", "sqlite", "--tools-file", toolConflictFile}, + wantErr: true, + errString: "resource conflicts detected", + }, + { + desc: "source conflict error", + args: []string{"--prebuilt", "sqlite", "--tools-file", sourceConflictFile}, + wantErr: true, + errString: "resource conflicts detected", + }, + { + desc: "toolset conflict error", + args: []string{"--prebuilt", "sqlite", "--tools-file", toolsetConflictFile}, + wantErr: true, + errString: "resource conflicts detected", + }, + { + desc: "legacy auth additive", + args: []string{"--prebuilt", "sqlite", "--tools-file", authFile}, + wantErr: false, + cfgCheck: func(cfg server.ServerConfig) error { + if _, ok := cfg.AuthServiceConfigs["legacy-auth"]; !ok { + return fmt.Errorf("legacy auth source not merged into auth services") + } + return nil + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + cmd, output, err := invokeCommandWithContext(ctx, tc.args) + + if tc.wantErr { + if err == nil { + t.Fatalf("expected an error but got none") + } + if !strings.Contains(err.Error(), tc.errString) { + t.Errorf("expected error message to contain %q, but got %q", tc.errString, err.Error()) + } + } else { + if err != nil && err != context.DeadlineExceeded && err != context.Canceled { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(output, "Server ready to serve!") { + t.Errorf("server did not start successfully (no ready message found). Output:\n%s", output) + } + if tc.cfgCheck != nil { + if err := tc.cfgCheck(cmd.cfg); err != nil { + t.Errorf("config check failed: %v", err) + } + } + } + }) + } +} + +func TestDefaultToolsFileBehavior(t *testing.T) { + t.Setenv("SQLITE_DATABASE", "test.db") + testCases := []struct { + desc string + args []string + expectRun bool + errString string + }{ + { + desc: "no flags (defaults to tools.yaml)", + args: []string{}, + expectRun: false, + errString: "tools.yaml", // Expect error because tools.yaml doesn't exist in test env + }, + { + desc: "prebuilt only (skips tools.yaml)", + args: []string{"--prebuilt", "sqlite"}, + expectRun: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _, output, err := invokeCommandWithContext(ctx, tc.args) + + if tc.expectRun { + if err != nil && err != context.DeadlineExceeded && err != context.Canceled { + t.Fatalf("expected server start, got error: %v", err) + } + // Verify it actually started + if !strings.Contains(output, "Server ready to serve!") { + t.Errorf("server did not start successfully (no ready message found). Output:\n%s", output) + } + } else { + if err == nil { + t.Fatalf("expected error reading default file, got nil") + } + if !strings.Contains(err.Error(), tc.errString) { + t.Errorf("expected error message to contain %q, but got %q", tc.errString, err.Error()) + } + } + }) + } +} diff --git a/cmd/version.txt b/cmd/version.txt index ca222b7cf3..2094a100ca 100644 --- a/cmd/version.txt +++ b/cmd/version.txt @@ -1 +1 @@ -0.23.0 +0.24.0 diff --git a/docs/en/concepts/telemetry/index.md b/docs/en/concepts/telemetry/index.md index 862c3832e2..49b7c9edca 100644 --- a/docs/en/concepts/telemetry/index.md +++ b/docs/en/concepts/telemetry/index.md @@ -183,11 +183,11 @@ Protocol (OTLP). If you would like to use a collector, please refer to this The following flags are used to determine Toolbox's telemetry configuration: -| **flag** | **type** | **description** | -|----------------------------|----------|------------------------------------------------------------------------------------------------------------------| -| `--telemetry-gcp` | bool | Enable exporting directly to Google Cloud Monitoring. Default is `false`. | -| `--telemetry-otlp` | string | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. ""). | -| `--telemetry-service-name` | string | Sets the value of the `service.name` resource attribute. Default is `toolbox`. | +| **flag** | **type** | **description** | +|----------------------------|----------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `--telemetry-gcp` | bool | Enable exporting directly to Google Cloud Monitoring. Default is `false`. | +| `--telemetry-otlp` | string | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. "127.0.0.1:4318"). To pass an insecure endpoint here, set environment variable `OTEL_EXPORTER_OTLP_INSECURE=true`. | +| `--telemetry-service-name` | string | Sets the value of the `service.name` resource attribute. Default is `toolbox`. | In addition to the flags noted above, you can also make additional configuration for OpenTelemetry via the [General SDK Configuration][sdk-configuration] through @@ -207,5 +207,5 @@ To enable Google Cloud Exporter: To enable OTLP Exporter, provide Collector endpoint: ```bash -./toolbox --telemetry-otlp="http://127.0.0.1:4553" +./toolbox --telemetry-otlp="127.0.0.1:4553" ``` diff --git a/docs/en/getting-started/colab_quickstart.ipynb b/docs/en/getting-started/colab_quickstart.ipynb index 9f7bbcf747..a2e2f989e0 100644 --- a/docs/en/getting-started/colab_quickstart.ipynb +++ b/docs/en/getting-started/colab_quickstart.ipynb @@ -234,7 +234,7 @@ }, "outputs": [], "source": [ - "version = \"0.23.0\" # x-release-please-version\n", + "version = \"0.24.0\" # x-release-please-version\n", "! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n", "\n", "# Make the binary executable\n", diff --git a/docs/en/getting-started/introduction/_index.md b/docs/en/getting-started/introduction/_index.md index 6206c75c30..f5f7d76836 100644 --- a/docs/en/getting-started/introduction/_index.md +++ b/docs/en/getting-started/introduction/_index.md @@ -103,7 +103,7 @@ To install Toolbox as a binary on Linux (AMD64): ```sh # see releases page for other versions -export VERSION=0.23.0 +export VERSION=0.24.0 curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox chmod +x toolbox ``` @@ -114,7 +114,7 @@ To install Toolbox as a binary on macOS (Apple Silicon): ```sh # see releases page for other versions -export VERSION=0.23.0 +export VERSION=0.24.0 curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox chmod +x toolbox ``` @@ -125,7 +125,7 @@ To install Toolbox as a binary on macOS (Intel): ```sh # see releases page for other versions -export VERSION=0.23.0 +export VERSION=0.24.0 curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox chmod +x toolbox ``` @@ -136,7 +136,7 @@ To install Toolbox as a binary on Windows (Command Prompt): ```cmd :: see releases page for other versions -set VERSION=0.23.0 +set VERSION=0.24.0 curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe" ``` @@ -146,7 +146,7 @@ To install Toolbox as a binary on Windows (PowerShell): ```powershell # see releases page for other versions -$VERSION = "0.23.0" +$VERSION = "0.24.0" curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe" ``` @@ -158,7 +158,7 @@ You can also install Toolbox as a container: ```sh # see releases page for other versions -export VERSION=0.23.0 +export VERSION=0.24.0 docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION ``` @@ -177,7 +177,7 @@ To install from source, ensure you have the latest version of [Go installed](https://go.dev/doc/install), and then run the following command: ```sh -go install github.com/googleapis/genai-toolbox@v0.23.0 +go install github.com/googleapis/genai-toolbox@v0.24.0 ``` {{% /tab %}} diff --git a/docs/en/getting-started/mcp_quickstart/_index.md b/docs/en/getting-started/mcp_quickstart/_index.md index 05de3eb9a3..f07528d2bf 100644 --- a/docs/en/getting-started/mcp_quickstart/_index.md +++ b/docs/en/getting-started/mcp_quickstart/_index.md @@ -105,7 +105,7 @@ In this section, we will download Toolbox, configure our tools in a ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/docs/en/getting-started/quickstart/shared/configure_toolbox.md b/docs/en/getting-started/quickstart/shared/configure_toolbox.md index 0bda1034ae..dda247e2ef 100644 --- a/docs/en/getting-started/quickstart/shared/configure_toolbox.md +++ b/docs/en/getting-started/quickstart/shared/configure_toolbox.md @@ -13,7 +13,7 @@ In this section, we will download Toolbox, configure our tools in a ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/docs/en/how-to/connect-ide/looker_mcp.md b/docs/en/how-to/connect-ide/looker_mcp.md index 1037401b1d..c9bb250ffd 100644 --- a/docs/en/how-to/connect-ide/looker_mcp.md +++ b/docs/en/how-to/connect-ide/looker_mcp.md @@ -49,19 +49,19 @@ to expose your developer assistant tools to a Looker instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/mssql_mcp.md b/docs/en/how-to/connect-ide/mssql_mcp.md index c8b6d22520..defb5f0e18 100644 --- a/docs/en/how-to/connect-ide/mssql_mcp.md +++ b/docs/en/how-to/connect-ide/mssql_mcp.md @@ -45,19 +45,19 @@ instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/mysql_mcp.md b/docs/en/how-to/connect-ide/mysql_mcp.md index 99ac4ae4cb..0d8d5a1ba5 100644 --- a/docs/en/how-to/connect-ide/mysql_mcp.md +++ b/docs/en/how-to/connect-ide/mysql_mcp.md @@ -43,19 +43,19 @@ expose your developer assistant tools to a MySQL instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/neo4j_mcp.md b/docs/en/how-to/connect-ide/neo4j_mcp.md index be775c3ae9..56795aef0f 100644 --- a/docs/en/how-to/connect-ide/neo4j_mcp.md +++ b/docs/en/how-to/connect-ide/neo4j_mcp.md @@ -44,19 +44,19 @@ expose your developer assistant tools to a Neo4j instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/postgres_mcp.md b/docs/en/how-to/connect-ide/postgres_mcp.md index e40f437b68..6ec92b948e 100644 --- a/docs/en/how-to/connect-ide/postgres_mcp.md +++ b/docs/en/how-to/connect-ide/postgres_mcp.md @@ -56,19 +56,19 @@ Omni](https://cloud.google.com/alloydb/omni/current/docs/overview). {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/sqlite_mcp.md b/docs/en/how-to/connect-ide/sqlite_mcp.md index c5336281e6..1493a71885 100644 --- a/docs/en/how-to/connect-ide/sqlite_mcp.md +++ b/docs/en/how-to/connect-ide/sqlite_mcp.md @@ -43,19 +43,19 @@ to expose your developer assistant tools to a SQLite instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/export_telemetry.md b/docs/en/how-to/export_telemetry.md index 0265ce27fb..f9d8c88404 100644 --- a/docs/en/how-to/export_telemetry.md +++ b/docs/en/how-to/export_telemetry.md @@ -79,12 +79,16 @@ There are a couple of steps to run and use a Collector. ``` 1. Run toolbox with the `--telemetry-otlp` flag. Configure it to send them to - `http://127.0.0.1:4553` (for HTTP) or the Collector's URL. + `127.0.0.1:4553` (for HTTP) or the Collector's URL. ```bash - ./toolbox --telemetry-otlp=http://127.0.0.1:4553 + ./toolbox --telemetry-otlp=127.0.0.1:4553 ``` + {{< notice tip >}} + To pass an insecure endpoint, set environment variable `OTEL_EXPORTER_OTLP_INSECURE=true`. + {{< /notice >}} + 1. Once telemetry datas are collected, you can view them in your telemetry backend. If you are using GCP exporters, telemetry will be visible in GCP dashboard at [Metrics Explorer][metrics-explorer] and [Trace diff --git a/docs/en/reference/cli.md b/docs/en/reference/cli.md index 490e63fe2a..1c9829995e 100644 --- a/docs/en/reference/cli.md +++ b/docs/en/reference/cli.md @@ -16,14 +16,14 @@ description: > | | `--log-level` | Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'. | `info` | | | `--logging-format` | Specify logging format to use. Allowed: 'standard' or 'JSON'. | `standard` | | `-p` | `--port` | Port the server will listen on. | `5000` | -| | `--prebuilt` | Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | | +| | `--prebuilt` | Use a prebuilt tool configuration by source type. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | | | | `--stdio` | Listens via MCP STDIO instead of acting as a remote HTTP server. | | | | `--telemetry-gcp` | Enable exporting directly to Google Cloud Monitoring. | | | | `--telemetry-otlp` | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318') | | | | `--telemetry-service-name` | Sets the value of the service.name resource attribute for telemetry data. | `toolbox` | -| | `--tools-file` | File path specifying the tool configuration. Cannot be used with --prebuilt, --tools-files, or --tools-folder. | | -| | `--tools-files` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --prebuilt, --tools-file, or --tools-folder. | | -| | `--tools-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --prebuilt, --tools-file, or --tools-files. | | +| | `--tools-file` | File path specifying the tool configuration. Cannot be used with --tools-files or --tools-folder. | | +| | `--tools-files` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file or --tools-folder. | | +| | `--tools-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file or --tools-files. | | | | `--ui` | Launches the Toolbox UI web server. | | | | `--allowed-origins` | Specifies a list of origins permitted to access this server. | `*` | | `-v` | `--version` | version for toolbox | | @@ -46,6 +46,9 @@ description: > ```bash # Basic server with custom port configuration ./toolbox --tools-file "tools.yaml" --port 8080 + +# Server with prebuilt + custom tools configurations +./toolbox --tools-file tools.yaml --prebuilt alloydb-postgres ``` ### Tool Configuration Sources @@ -72,8 +75,8 @@ The CLI supports multiple mutually exclusive ways to specify tool configurations {{< notice tip >}} The CLI enforces mutual exclusivity between configuration source flags, -preventing simultaneous use of `--prebuilt` with file-based options, and -ensuring only one of `--tools-file`, `--tools-files`, or `--tools-folder` is +preventing simultaneous use of the file-based options ensuring only one of +`--tools-file`, `--tools-files`, or `--tools-folder` is used at a time. {{< /notice >}} diff --git a/docs/en/reference/prebuilt-tools.md b/docs/en/reference/prebuilt-tools.md index 7f0ee52821..b340ac055a 100644 --- a/docs/en/reference/prebuilt-tools.md +++ b/docs/en/reference/prebuilt-tools.md @@ -13,6 +13,12 @@ allowing developers to interact with and take action on databases. See guides, [Connect from your IDE](../how-to/connect-ide/_index.md), for details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. +{{< notice tip >}} +You can now use `--prebuilt` along `--tools-file`, `--tools-files`, or +`--tools-folder` to combine prebuilt configs with custom tools. +See [Usage Examples](../reference/cli.md#examples). +{{< /notice >}} + ## AlloyDB Postgres * `--prebuilt` value: `alloydb-postgres` diff --git a/docs/en/resources/sources/cloud-gda.md b/docs/en/resources/sources/cloud-gda.md new file mode 100644 index 0000000000..dc400f17e8 --- /dev/null +++ b/docs/en/resources/sources/cloud-gda.md @@ -0,0 +1,40 @@ +--- +title: "Gemini Data Analytics" +type: docs +weight: 1 +description: > + A "cloud-gemini-data-analytics" source provides a client for the Gemini Data Analytics API. +aliases: + - /resources/sources/cloud-gemini-data-analytics +--- + +## About + +The `cloud-gemini-data-analytics` source provides a client to interact with the [Gemini Data Analytics API](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/reference/rest). This allows tools to send natural language queries to the API. + +Authentication can be handled in two ways: + +1. **Application Default Credentials (ADC) (Recommended):** By default, the source uses ADC to authenticate with the API. The Toolbox server will fetch the credentials from its running environment (server-side authentication). This is the recommended method. +2. **Client-side OAuth:** If `useClientOAuth` is set to `true`, the source expects the authentication token to be provided by the caller when making a request to the Toolbox server (typically via an HTTP Bearer token). The Toolbox server will then forward this token to the underlying Gemini Data Analytics API calls. + +## Example + +```yaml +sources: + my-gda-source: + kind: cloud-gemini-data-analytics + projectId: my-project-id + + my-oauth-gda-source: + kind: cloud-gemini-data-analytics + projectId: my-project-id + useClientOAuth: true +``` + +## Reference + +| **field** | **type** | **required** | **description** | +| -------------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| kind | string | true | Must be "cloud-gemini-data-analytics". | +| projectId | string | true | The Google Cloud Project ID where the API is enabled. | +| useClientOAuth | boolean | false | If true, the source uses the token provided by the caller (forwarded to the API). Otherwise, it uses server-side Application Default Credentials (ADC). Defaults to `false`. | diff --git a/docs/en/resources/sources/cloud-sql-mysql.md b/docs/en/resources/sources/cloud-sql-mysql.md index 188bcbce26..e9f89f22a9 100644 --- a/docs/en/resources/sources/cloud-sql-mysql.md +++ b/docs/en/resources/sources/cloud-sql-mysql.md @@ -31,6 +31,9 @@ to a database by following these instructions][csql-mysql-quickstart]. - [`mysql-list-active-queries`](../tools/mysql/mysql-list-active-queries.md) List active queries in Cloud SQL for MySQL. +- [`mysql-get-query-plan`](../tools/mysql/mysql-get-query-plan.md) + Provide information about how MySQL executes a SQL statement (EXPLAIN). + - [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md) List tables in a Cloud SQL for MySQL database. diff --git a/docs/en/resources/sources/mysql.md b/docs/en/resources/sources/mysql.md index 44d46195ac..95f2b96d7c 100644 --- a/docs/en/resources/sources/mysql.md +++ b/docs/en/resources/sources/mysql.md @@ -25,6 +25,9 @@ reliability, performance, and ease of use. - [`mysql-list-active-queries`](../tools/mysql/mysql-list-active-queries.md) List active queries in MySQL. +- [`mysql-get-query-plan`](../tools/mysql/mysql-get-query-plan.md) + Provide information about how MySQL executes a SQL statement (EXPLAIN). + - [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md) List tables in a MySQL database. diff --git a/docs/en/resources/sources/oracle.md b/docs/en/resources/sources/oracle.md index 4932ea6e22..51fa18fe13 100644 --- a/docs/en/resources/sources/oracle.md +++ b/docs/en/resources/sources/oracle.md @@ -18,10 +18,10 @@ DW) database workloads. ## Available Tools - [`oracle-sql`](../tools/oracle/oracle-sql.md) - Execute pre-defined prepared SQL queries in Oracle. + Execute pre-defined prepared SQL queries in Oracle. - [`oracle-execute-sql`](../tools/oracle/oracle-execute-sql.md) - Run parameterized SQL queries in Oracle. + Run parameterized SQL queries in Oracle. ## Requirements @@ -33,6 +33,25 @@ user][oracle-users] to log in to the database with the necessary permissions. [oracle-users]: https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/CREATE-USER.html +### Oracle Driver Requirement (Conditional) + +The Oracle source offers two connection drivers: + +1. **Pure Go Driver (`useOCI: false`, default):** Uses the `go-ora` library. + This driver is simpler and does not require any local Oracle software + installation, but it **lacks support for advanced features** like Oracle + Wallets or Kerberos authentication. + +2. **OCI-Based Driver (`useOCI: true`):** Uses the `godror` library, which + provides access to **advanced Oracle features** like Digital Wallet support. + +If you set `useOCI: true`, you **must** install the **Oracle Instant Client** +libraries on the machine where this tool runs. + +You can download the Instant Client from the official Oracle website: [Oracle +Instant Client +Downloads](https://www.oracle.com/database/technologies/instant-client/downloads.html) + ## Connection Methods You can configure the connection to your Oracle database using one of the @@ -66,12 +85,15 @@ using a TNS (Transparent Network Substrate) alias. containing it. This setting will override the `TNS_ADMIN` environment variable. -## Example +## Examples + +This example demonstrates the four connection methods you could choose from: ```yaml sources: my-oracle-source: kind: oracle + # --- Choose one connection method --- # 1. Host, Port, and Service Name host: 127.0.0.1 @@ -88,6 +110,43 @@ sources: user: ${USER_NAME} password: ${PASSWORD} + # Optional: Set to true to use the OCI-based driver for advanced features (Requires Oracle Instant Client) +``` + +### Using an Oracle Wallet + +Oracle Wallet allows you to store credentails used for database connection. Depending whether you are using an OCI-based driver, the wallet configuration is different. + +#### Pure Go Driver (`useOCI: false`) - Oracle Wallet + +The `go-ora` driver uses the `walletLocation` field to connect to a database secured with an Oracle Wallet without standard username and password. + +```yaml +sources: + pure-go-wallet: + kind: oracle + connectionString: "127.0.0.1:1521/XEPDB1" + user: ${USER_NAME} + password: ${PASSWORD} + # The TNS Alias is often required to connect to a service registered in tnsnames.ora + tnsAlias: "SECURE_DB_ALIAS" + walletLocation: "/path/to/my/wallet/directory" +``` + +#### OCI-Based Driver (`useOCI: true`) - Oracle Wallet + +For the OCI-based driver, wallet authentication is triggered by setting tnsAdmin to the wallet directory and connecting via a tnsAlias. + +```yaml +sources: + oci-wallet: + kind: oracle + connectionString: "127.0.0.1:1521/XEPDB1" + user: ${USER_NAME} + password: ${PASSWORD} + tnsAlias: "WALLET_DB_ALIAS" + tnsAdmin: "/opt/oracle/wallet" # Directory containing tnsnames.ora, sqlnet.ora, and wallet files + useOCI: true ``` {{< notice tip >}} @@ -97,14 +156,15 @@ instead of hardcoding your secrets into the configuration file. ## Reference -| **field** | **type** | **required** | **description** | -|------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------| -| kind | string | true | Must be "oracle". | -| user | string | true | Name of the Oracle user to connect as (e.g. "my-oracle-user"). | -| password | string | true | Password of the Oracle user (e.g. "my-password"). | -| host | string | false | IP address or hostname to connect to (e.g. "127.0.0.1"). Required if not using `connectionString` or `tnsAlias`. | -| port | integer | false | Port to connect to (e.g. "1521"). Required if not using `connectionString` or `tnsAlias`. | -| serviceName | string | false | The Oracle service name of the database to connect to. Required if not using `connectionString` or `tnsAlias`. | -| connectionString | string | false | A direct connection string (e.g. "hostname:port/servicename"). Use as an alternative to `host`, `port`, and `serviceName`. | -| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. | -| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. | +| **field** | **type** | **required** | **description** | +|------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| kind | string | true | Must be "oracle". | +| user | string | true | Name of the Oracle user to connect as (e.g. "my-oracle-user"). | +| password | string | true | Password of the Oracle user (e.g. "my-password"). | +| host | string | false | IP address or hostname to connect to (e.g. "127.0.0.1"). Required if not using `connectionString` or `tnsAlias`. | +| port | integer | false | Port to connect to (e.g. "1521"). Required if not using `connectionString` or `tnsAlias`. | +| serviceName | string | false | The Oracle service name of the database to connect to. Required if not using `connectionString` or `tnsAlias`. | +| connectionString | string | false | A direct connection string (e.g. "hostname:port/servicename"). Use as an alternative to `host`, `port`, and `serviceName`. | +| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. | +| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. | +| useOCI | bool | false | If true, uses the OCI-based driver (godror) which supports Oracle Wallet/Kerberos but requires the Oracle Instant Client libraries to be installed. Defaults to false (pure Go driver). | diff --git a/docs/en/resources/tools/cloudgda/_index.md b/docs/en/resources/tools/cloudgda/_index.md new file mode 100644 index 0000000000..63e1189632 --- /dev/null +++ b/docs/en/resources/tools/cloudgda/_index.md @@ -0,0 +1,7 @@ +--- +title: "Gemini Data Analytics" +type: docs +weight: 1 +description: > + Tools for Gemini Data Analytics. +--- diff --git a/docs/en/resources/tools/cloudgda/cloud-gda-query.md b/docs/en/resources/tools/cloudgda/cloud-gda-query.md new file mode 100644 index 0000000000..faf119d6e6 --- /dev/null +++ b/docs/en/resources/tools/cloudgda/cloud-gda-query.md @@ -0,0 +1,92 @@ +--- +title: "Gemini Data Analytics QueryData" +type: docs +weight: 1 +description: > + A tool to convert natural language queries into SQL statements using the Gemini Data Analytics QueryData API. +aliases: + - /resources/tools/cloud-gemini-data-analytics-query +--- + +## About + +The `cloud-gemini-data-analytics-query` tool allows you to send natural language questions to the Gemini Data Analytics API and receive structured responses containing SQL queries, natural language answers, and explanations. For details on defining data agent context for database data sources, see the official [documentation](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases). + +## Example + +```yaml +tools: + my-gda-query-tool: + kind: cloud-gemini-data-analytics-query + source: my-gda-source + description: "Use this tool to send natural language queries to the Gemini Data Analytics API and receive SQL, natural language answers, and explanations." + location: ${your_database_location} + context: + datasourceReferences: + cloudSqlReference: + databaseReference: + projectId: "${your_project_id}" + region: "${your_database_instance_region}" + instanceId: "${your_database_instance_id}" + databaseId: "${your_database_name}" + engine: "POSTGRESQL" + agentContextReference: + contextSetId: "${your_context_set_id}" # E.g. projects/${project_id}/locations/${context_set_location}/contextSets/${context_set_id} + generationOptions: + generateQueryResult: true + generateNaturalLanguageAnswer: true + generateExplanation: true + generateDisambiguationQuestion: true +``` + +### Usage Flow + +When using this tool, a `prompt` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration. + +The structure of the response depends on the `generationOptions` configured in your tool definition (e.g., enabling `generateQueryResult` will include the SQL query results). + +See [Data Analytics API REST documentation](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1alpha/projects.locations/queryData?rep_location=global) for details. + +**Example Input Prompt:** + +```text +How many accounts who have region in Prague are eligible for loans? A3 contains the data of region. +``` + +**Example API Response:** + +```json +{ + "generatedQuery": "SELECT COUNT(T1.account_id) FROM account AS T1 INNER JOIN loan AS T2 ON T1.account_id = T2.account_id INNER JOIN district AS T3 ON T1.district_id = T3.district_id WHERE T3.A3 = 'Prague'", + "intentExplanation": "I found a template that matches the user's question. The template asks about the number of accounts who have region in a given city and are eligible for loans. The question asks about the number of accounts who have region in Prague and are eligible for loans. The template's parameterized SQL is 'SELECT COUNT(T1.account_id) FROM account AS T1 INNER JOIN loan AS T2 ON T1.account_id = T2.account_id INNER JOIN district AS T3 ON T1.district_id = T3.district_id WHERE T3.A3 = ?'. I will replace the named parameter '?' with 'Prague'.", + "naturalLanguageAnswer": "There are 84 accounts from the Prague region that are eligible for loans.", + "queryResult": { + "columns": [ + { + "type": "INT64" + } + ], + "rows": [ + { + "values": [ + { + "value": "84" + } + ] + } + ], + "totalRowCount": "1" + } +} +``` + +## Reference + +| **field** | **type** | **required** | **description** | +| ----------------- | :------: | :----------: | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| kind | string | true | Must be "cloud-gemini-data-analytics-query". | +| source | string | true | The name of the `cloud-gemini-data-analytics` source to use. | +| description | string | true | A description of the tool's purpose. | +| location | string | true | The Google Cloud location of the target database resource (e.g., "us-central1"). This is used to construct the parent resource name in the API call. | +| context | object | true | The context for the query, including datasource references. See [QueryDataContext](https://github.com/googleapis/googleapis/blob/b32495a713a68dd0dff90cf0b24021debfca048a/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto#L156) for details. | +| generationOptions | object | false | Options for generating the response. See [GenerationOptions](https://github.com/googleapis/googleapis/blob/b32495a713a68dd0dff90cf0b24021debfca048a/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto#L135) for details. | diff --git a/docs/en/resources/tools/mysql/mysql-get-query-plan.md b/docs/en/resources/tools/mysql/mysql-get-query-plan.md new file mode 100644 index 0000000000..d77b81e097 --- /dev/null +++ b/docs/en/resources/tools/mysql/mysql-get-query-plan.md @@ -0,0 +1,39 @@ +--- +title: "mysql-get-query-plan" +type: docs +weight: 1 +description: > + A "mysql-get-query-plan" tool gets the execution plan for a SQL statement against a MySQL + database. +aliases: +- /resources/tools/mysql-get-query-plan +--- + +## About + +A `mysql-get-query-plan` tool gets the execution plan for a SQL statement against a MySQL +database. It's compatible with any of the following sources: + +- [cloud-sql-mysql](../../sources/cloud-sql-mysql.md) +- [mysql](../../sources/mysql.md) + +`mysql-get-query-plan` takes one input parameter `sql_statement` and gets the execution plan for the SQL +statement against the `source`. + +## Example + +```yaml +tools: + get_query_plan_tool: + kind: mysql-get-query-plan + source: my-mysql-instance + description: Use this tool to get the execution plan for a sql statement. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------| +| kind | string | true | Must be "mysql-get-query-plan". | +| source | string | true | Name of the source the SQL should execute on. | +| description | string | true | Description of the tool that is passed to the LLM. | diff --git a/docs/en/samples/alloydb/ai-nl/alloydb_ai_nl.ipynb b/docs/en/samples/alloydb/ai-nl/alloydb_ai_nl.ipynb index 330905b66d..fc8e5300b1 100644 --- a/docs/en/samples/alloydb/ai-nl/alloydb_ai_nl.ipynb +++ b/docs/en/samples/alloydb/ai-nl/alloydb_ai_nl.ipynb @@ -771,7 +771,7 @@ }, "outputs": [], "source": [ - "version = \"0.23.0\" # x-release-please-version\n", + "version = \"0.24.0\" # x-release-please-version\n", "! curl -L -o /content/toolbox https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n", "\n", "# Make the binary executable\n", diff --git a/docs/en/samples/alloydb/mcp_quickstart.md b/docs/en/samples/alloydb/mcp_quickstart.md index 3609729d4a..c047416428 100644 --- a/docs/en/samples/alloydb/mcp_quickstart.md +++ b/docs/en/samples/alloydb/mcp_quickstart.md @@ -123,7 +123,7 @@ In this section, we will download and install the Toolbox binary. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - export VERSION="0.23.0" + export VERSION="0.24.0" curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/$OS/toolbox ``` diff --git a/docs/en/samples/bigquery/colab_quickstart_bigquery.ipynb b/docs/en/samples/bigquery/colab_quickstart_bigquery.ipynb index 85d447c4a5..eb551ca015 100644 --- a/docs/en/samples/bigquery/colab_quickstart_bigquery.ipynb +++ b/docs/en/samples/bigquery/colab_quickstart_bigquery.ipynb @@ -220,7 +220,7 @@ }, "outputs": [], "source": [ - "version = \"0.23.0\" # x-release-please-version\n", + "version = \"0.24.0\" # x-release-please-version\n", "! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n", "\n", "# Make the binary executable\n", diff --git a/docs/en/samples/bigquery/local_quickstart.md b/docs/en/samples/bigquery/local_quickstart.md index 506232e856..badda3f75e 100644 --- a/docs/en/samples/bigquery/local_quickstart.md +++ b/docs/en/samples/bigquery/local_quickstart.md @@ -179,7 +179,7 @@ to use BigQuery, and then run the Toolbox server. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/docs/en/samples/bigquery/mcp_quickstart/_index.md b/docs/en/samples/bigquery/mcp_quickstart/_index.md index 2341054e6e..6f0b44d18b 100644 --- a/docs/en/samples/bigquery/mcp_quickstart/_index.md +++ b/docs/en/samples/bigquery/mcp_quickstart/_index.md @@ -98,7 +98,7 @@ In this section, we will download Toolbox, configure our tools in a ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/docs/en/samples/looker/looker_gemini.md b/docs/en/samples/looker/looker_gemini.md index 2d741958cb..0fc81afc32 100644 --- a/docs/en/samples/looker/looker_gemini.md +++ b/docs/en/samples/looker/looker_gemini.md @@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/docs/en/samples/looker/looker_gemini_oauth/_index.md b/docs/en/samples/looker/looker_gemini_oauth/_index.md index b57a142c62..6eb730ceee 100644 --- a/docs/en/samples/looker/looker_gemini_oauth/_index.md +++ b/docs/en/samples/looker/looker_gemini_oauth/_index.md @@ -48,7 +48,7 @@ In this section, we will download Toolbox and run the Toolbox server. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/docs/en/samples/looker/looker_mcp_inspector/_index.md b/docs/en/samples/looker/looker_mcp_inspector/_index.md index 985f041a4a..ef3a51c4e9 100644 --- a/docs/en/samples/looker/looker_mcp_inspector/_index.md +++ b/docs/en/samples/looker/looker_mcp_inspector/_index.md @@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/gemini-extension.json b/gemini-extension.json index 08594982cf..b068279cd6 100644 --- a/gemini-extension.json +++ b/gemini-extension.json @@ -1,6 +1,6 @@ { "name": "mcp-toolbox-for-databases", - "version": "0.23.0", + "version": "0.24.0", "description": "MCP Toolbox for Databases is an open-source MCP server for more than 30 different datasources.", "contextFileName": "MCP-TOOLBOX-EXTENSION.md" } \ No newline at end of file diff --git a/go.mod b/go.mod index ba4f69afbc..e0ed921ac5 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( cloud.google.com/go/dataplex v1.28.0 cloud.google.com/go/dataproc/v2 v2.15.0 cloud.google.com/go/firestore v1.20.0 - cloud.google.com/go/geminidataanalytics v0.2.1 + cloud.google.com/go/geminidataanalytics v0.3.0 cloud.google.com/go/longrunning v0.7.0 cloud.google.com/go/spanner v1.86.1 github.com/ClickHouse/clickhouse-go/v2 v2.40.3 @@ -22,7 +22,7 @@ require ( github.com/cenkalti/backoff/v5 v5.0.3 github.com/couchbase/gocb/v2 v2.11.1 github.com/couchbase/tools-common/http v1.0.9 - github.com/elastic/elastic-transport-go/v8 v8.7.0 + github.com/elastic/elastic-transport-go/v8 v8.8.0 github.com/elastic/go-elasticsearch/v9 v9.2.0 github.com/fsnotify/fsnotify v1.9.0 github.com/go-chi/chi/v5 v5.2.3 @@ -33,6 +33,7 @@ require ( github.com/go-playground/validator/v10 v10.28.0 github.com/go-sql-driver/mysql v1.9.3 github.com/goccy/go-yaml v1.18.0 + github.com/godror/godror v0.49.6 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.6 @@ -41,7 +42,7 @@ require ( github.com/microsoft/go-mssqldb v1.9.3 github.com/nakagami/firebirdsql v0.9.15 github.com/neo4j/neo4j-go-driver/v5 v5.28.4 - github.com/redis/go-redis/v9 v9.16.0 + github.com/redis/go-redis/v9 v9.17.2 github.com/sijms/go-ora/v2 v2.9.0 github.com/spf13/cobra v1.10.1 github.com/thlib/go-timezone-local v0.0.7 @@ -91,6 +92,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0 // indirect github.com/PuerkitoBio/goquery v1.10.3 // indirect + github.com/VictoriaMetrics/easyproto v0.1.4 // indirect github.com/ajg/form v1.5.1 // indirect github.com/apache/arrow/go/v15 v15.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -107,11 +109,13 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/gabriel-vasile/mimetype v1.4.10 // indirect github.com/go-jose/go-jose/v4 v4.1.2 // indirect + github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/goccy/go-json v0.10.5 // indirect + github.com/godror/knownpb v0.3.0 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect @@ -181,7 +185,7 @@ require ( golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.38.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect google.golang.org/grpc v1.76.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect diff --git a/go.sum b/go.sum index d76e60f469..eeac2b4fd4 100644 --- a/go.sum +++ b/go.sum @@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2 cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w= cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM= cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0= -cloud.google.com/go/geminidataanalytics v0.2.1 h1:gtG/9VlUJpL67yukFen/twkAEHliYvW7610Rlnn5rpQ= -cloud.google.com/go/geminidataanalytics v0.2.1/go.mod h1:gIsj/ELDCzVbw24185zwjXgbzYiqdGe7TSSK2HrdtA0= +cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI= +cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg= cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60= cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo= cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg= @@ -683,6 +683,10 @@ github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8 github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo= github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y= +github.com/UNO-SOFT/zlog v0.8.1 h1:TEFkGJHtUfTRgMkLZiAjLSHALjwSBdw6/zByMC5GJt4= +github.com/UNO-SOFT/zlog v0.8.1/go.mod h1:yqFOjn3OhvJ4j7ArJqQNA+9V+u6t9zSAyIZdWdMweWc= +github.com/VictoriaMetrics/easyproto v0.1.4 h1:r8cNvo8o6sR4QShBXQd1bKw/VVLSQma/V2KhTBPf+Sc= +github.com/VictoriaMetrics/easyproto v0.1.4/go.mod h1:QlGlzaJnDfFd8Lk6Ci/fuLxfTo3/GThPs2KH23mv710= github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:3YVZUqkoev4mL+aCwVOSWV4M7pN+NURHL38Z2zq5JKA= github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:ymXt5bw5uSNu4jveerFxE0vNYxF8ncqbptntMaFMg3k= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= @@ -818,8 +822,8 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3 github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/elastic/elastic-transport-go/v8 v8.7.0 h1:OgTneVuXP2uip4BA658Xi6Hfw+PeIOod2rY3GVMGoVE= -github.com/elastic/elastic-transport-go/v8 v8.7.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk= +github.com/elastic/elastic-transport-go/v8 v8.8.0 h1:7k1Ua+qluFr6p1jfJjGDl97ssJS/P7cHNInzfxgBQAo= +github.com/elastic/elastic-transport-go/v8 v8.8.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk= github.com/elastic/go-elasticsearch/v9 v9.2.0 h1:COeL/g20+ixnUbffe4Wfbu88emrHjAq/LhVfmrjqRQs= github.com/elastic/go-elasticsearch/v9 v9.2.0/go.mod h1:2PB5YQPpY5tWbF65MRqzEXA31PZOdXCkloQSOZtU14I= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -884,6 +888,8 @@ github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vb github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U= github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= +github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -909,6 +915,10 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/godror/godror v0.49.6 h1:ts4ZGw8uLJ42e1D7aXmVuSrld0/lzUzmIUjuUuQOgGM= +github.com/godror/godror v0.49.6/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8= +github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw= +github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= @@ -1172,6 +1182,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/neo4j/neo4j-go-driver/v5 v5.28.4 h1:7toxehVcYkZbyxV4W3Ib9VcnyRBQPucF+VwNNmtSXi4= github.com/neo4j/neo4j-go-driver/v5 v5.28.4/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k= +github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc= +github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68= github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= @@ -1210,8 +1222,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w= -github.com/redis/go-redis/v9 v9.16.0 h1:OotgqgLSRCmzfqChbQyG1PHC3tLNR89DG4jdOERSEP4= -github.com/redis/go-redis/v9 v9.16.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= +github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI= +github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= @@ -1671,6 +1683,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1990,8 +2004,8 @@ google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOl google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8 h1:a12a2/BiVRxRWIqBbfqoSK6tgq8cyUgMnEI81QlPge0= google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8/go.mod h1:1Ic78BnpzY8OaTCmzxJDP4qC9INZPbGZl+54RKjtyeI= -google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f h1:OiFuztEyBivVKDvguQJYWq1yDcfAHIID/FVrPR4oiI0= -google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f/go.mod h1:kprOiu9Tr0JYyD6DORrc4Hfyk3RFXqkQ3ctHEum3ZbM= +google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba h1:B14OtaXuMaCQsl2deSvNkyPKIzq3BjfxQp8d00QyWx4= +google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba/go.mod h1:G5IanEx8/PgI9w6CFcYQf7jMtHQhZruvfM1i3qOqk5U= google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 h1:tRPGkdGHuewF4UisLzzHHr1spKw92qLM98nIzxbC0wY= google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= diff --git a/internal/prebuiltconfigs/tools/cloud-sql-mysql.yaml b/internal/prebuiltconfigs/tools/cloud-sql-mysql.yaml index 0a6008eadc..63a73730b7 100644 --- a/internal/prebuiltconfigs/tools/cloud-sql-mysql.yaml +++ b/internal/prebuiltconfigs/tools/cloud-sql-mysql.yaml @@ -32,16 +32,9 @@ tools: source: cloud-sql-mysql-source description: Lists top N (default 10) ongoing queries from processlist and innodb_trx, ordered by execution time in descending order. Returns detailed information of those queries in json format, including process id, query, transaction duration, transaction wait duration, process time, transaction state, process state, username with host, transaction rows locked, transaction rows modified, and db schema. get_query_plan: - kind: mysql-sql + kind: mysql-get-query-plan source: cloud-sql-mysql-source description: "Provide information about how MySQL executes a SQL statement. Common use cases include: 1) analyze query plan to improve its performance, and 2) determine effectiveness of existing indexes and evalueate new ones." - statement: | - EXPLAIN FORMAT=JSON {{.sql_statement}}; - templateParameters: - - name: sql_statement - type: string - description: "the SQL statement to explain" - required: true list_tables: kind: mysql-list-tables source: cloud-sql-mysql-source diff --git a/internal/prebuiltconfigs/tools/mysql.yaml b/internal/prebuiltconfigs/tools/mysql.yaml index 9f85de3642..d3068550eb 100644 --- a/internal/prebuiltconfigs/tools/mysql.yaml +++ b/internal/prebuiltconfigs/tools/mysql.yaml @@ -36,16 +36,9 @@ tools: source: mysql-source description: Lists top N (default 10) ongoing queries from processlist and innodb_trx, ordered by execution time in descending order. Returns detailed information of those queries in json format, including process id, query, transaction duration, transaction wait duration, process time, transaction state, process state, username with host, transaction rows locked, transaction rows modified, and db schema. get_query_plan: - kind: mysql-sql + kind: mysql-get-query-plan source: mysql-source description: "Provide information about how MySQL executes a SQL statement. Common use cases include: 1) analyze query plan to improve its performance, and 2) determine effectiveness of existing indexes and evalueate new ones." - statement: | - EXPLAIN FORMAT=JSON {{.sql_statement}}; - templateParameters: - - name: sql_statement - type: string - description: "the SQL statement to explain" - required: true list_tables: kind: mysql-list-tables source: mysql-source diff --git a/internal/server/api.go b/internal/server/api.go index 5f701baa55..c03a214168 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -172,7 +172,14 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { accessToken := tools.AccessToken(r.Header.Get("Authorization")) // Check if this specific tool requires the standard authorization header - if tool.RequiresClientAuthorization(s.ResourceMgr) { + clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + s.logger.DebugContext(ctx, errMsg.Error()) + _ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound)) + return + } + if clientAuth { if accessToken == "" { err = fmt.Errorf("tool requires client authorization but access token is missing from the request header") s.logger.DebugContext(ctx, err.Error()) @@ -255,7 +262,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { } if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { - if tool.RequiresClientAuthorization(s.ResourceMgr) { + if clientAuth { // Propagate the original 401/403 error. s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err)) _ = render.Render(w, r, newErrResponse(err, statusCode)) diff --git a/internal/server/common_test.go b/internal/server/common_test.go index 4735a560ff..3953e1c7bc 100644 --- a/internal/server/common_test.go +++ b/internal/server/common_test.go @@ -77,9 +77,9 @@ func (t MockTool) Authorized(verifiedAuthServices []string) bool { return !t.unauthorized } -func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) bool { +func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) { // defaulted to false - return t.requiresClientAuthrorization + return t.requiresClientAuthrorization, nil } func (t MockTool) McpManifest() tools.McpManifest { @@ -119,8 +119,8 @@ func (t MockTool) McpManifest() tools.McpManifest { return mcpManifest } -func (t MockTool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) { + return "Authorization", nil } // MockPrompt is used to mock prompts in tests diff --git a/internal/server/mcp.go b/internal/server/mcp.go index 442369db5c..aecd2454f2 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -205,10 +205,13 @@ func (s *stdioSession) readLine(ctx context.Context) (string, error) { } // write writes to stdout with response to client -func (s *stdioSession) write(ctx context.Context, response any) error { - res, _ := json.Marshal(response) +func (s *stdioSession) write(_ context.Context, response any) error { + res, err := json.Marshal(response) + if err != nil { + return fmt.Errorf("failed to marshal response to JSON: %w", err) + } - _, err := fmt.Fprintf(s.writer, "%s\n", res) + _, err = fmt.Fprintf(s.writer, "%s\n", res) return err } diff --git a/internal/server/mcp/v20241105/method.go b/internal/server/mcp/v20241105/method.go index 6b2bf223e6..0cbec0d1d2 100644 --- a/internal/server/mcp/v20241105/method.go +++ b/internal/server/mcp/v20241105/method.go @@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } // Get access token - accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName())) + authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg + } + accessToken := tools.AccessToken(header.Get(authTokenHeadername)) // Check if this specific tool requires the standard authorization header - if tool.RequiresClientAuthorization(resourceMgr) { + clientAuth, err := tool.RequiresClientAuthorization(resourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg + } + if clientAuth { if accessToken == "" { return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized } @@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } // Upstream auth error if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") { - if tool.RequiresClientAuthorization(resourceMgr) { + if clientAuth { // Error with client credentials should pass down to the client return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } diff --git a/internal/server/mcp/v20250326/method.go b/internal/server/mcp/v20250326/method.go index c50b1b9636..a51bb161eb 100644 --- a/internal/server/mcp/v20250326/method.go +++ b/internal/server/mcp/v20250326/method.go @@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } // Get access token - accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName())) + authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg + } + accessToken := tools.AccessToken(header.Get(authTokenHeadername)) // Check if this specific tool requires the standard authorization header - if tool.RequiresClientAuthorization(resourceMgr) { + clientAuth, err := tool.RequiresClientAuthorization(resourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg + } + if clientAuth { if accessToken == "" { return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized } @@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } // Upstream auth error if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") { - if tool.RequiresClientAuthorization(resourceMgr) { + if clientAuth { // Error with client credentials should pass down to the client return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } diff --git a/internal/server/mcp/v20250618/method.go b/internal/server/mcp/v20250618/method.go index 183ada0188..ccfa5f102f 100644 --- a/internal/server/mcp/v20250618/method.go +++ b/internal/server/mcp/v20250618/method.go @@ -101,10 +101,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } // Get access token - accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName())) + authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg + } + accessToken := tools.AccessToken(header.Get(authTokenHeadername)) // Check if this specific tool requires the standard authorization header - if tool.RequiresClientAuthorization(resourceMgr) { + clientAuth, err := tool.RequiresClientAuthorization(resourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg + } + if clientAuth { if accessToken == "" { return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized } @@ -176,7 +186,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } // Upstream auth error if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") { - if tool.RequiresClientAuthorization(resourceMgr) { + if clientAuth { // Error with client credentials should pass down to the client return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } diff --git a/internal/sources/alloydbadmin/alloydbadmin.go b/internal/sources/alloydbadmin/alloydbadmin.go index d82126b2ea..f63b12fcd5 100644 --- a/internal/sources/alloydbadmin/alloydbadmin.go +++ b/internal/sources/alloydbadmin/alloydbadmin.go @@ -30,26 +30,6 @@ import ( const SourceKind string = "alloydb-admin" -type userAgentRoundTripper struct { - userAgent string - next http.RoundTripper -} - -func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - newReq := *req - newReq.Header = make(http.Header) - for k, v := range req.Header { - newReq.Header[k] = v - } - ua := newReq.Header.Get("User-Agent") - if ua == "" { - newReq.Header.Set("User-Agent", rt.userAgent) - } else { - newReq.Header.Set("User-Agent", ua+" "+rt.userAgent) - } - return rt.next.RoundTrip(&newReq) -} - // validate interface var _ sources.SourceConfig = Config{} @@ -87,10 +67,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So var client *http.Client if r.UseClientOAuth { client = &http.Client{ - Transport: &userAgentRoundTripper{ - userAgent: ua, - next: http.DefaultTransport, - }, + Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport), } } else { // Use Application Default Credentials @@ -99,10 +76,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("failed to find default credentials: %w", err) } baseClient := oauth2.NewClient(ctx, creds.TokenSource) - baseClient.Transport = &userAgentRoundTripper{ - userAgent: ua, - next: baseClient.Transport, - } + baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport) client = baseClient } @@ -136,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig { return s.Config } +func (s *Source) GetDefaultProject() string { + return s.DefaultProject +} + func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbrestapi.Service, error) { if s.UseClientOAuth { token := &oauth2.Token{AccessToken: accessToken} diff --git a/internal/sources/cloudgda/cloud_gda.go b/internal/sources/cloudgda/cloud_gda.go new file mode 100644 index 0000000000..a87ff11c59 --- /dev/null +++ b/internal/sources/cloudgda/cloud_gda.go @@ -0,0 +1,133 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package cloudgda + +import ( + "context" + "fmt" + "net/http" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util" + "go.opentelemetry.io/otel/trace" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +const SourceKind string = "cloud-gemini-data-analytics" +const Endpoint string = "https://geminidataanalytics.googleapis.com" + +// validate interface +var _ sources.SourceConfig = Config{} + +func init() { + if !sources.Register(SourceKind, newConfig) { + panic(fmt.Sprintf("source kind %q already registered", SourceKind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + ProjectID string `yaml:"projectId" validate:"required"` + UseClientOAuth bool `yaml:"useClientOAuth"` +} + +func (r Config) SourceConfigKind() string { + return SourceKind +} + +// Initialize initializes a Gemini Data Analytics Source instance. +func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) { + ua, err := util.UserAgentFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("error in User Agent retrieval: %s", err) + } + + var client *http.Client + if r.UseClientOAuth { + client = &http.Client{ + Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport), + } + } else { + // Use Application Default Credentials + // Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA + creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform") + if err != nil { + return nil, fmt.Errorf("failed to find default credentials: %w", err) + } + baseClient := oauth2.NewClient(ctx, creds.TokenSource) + baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport) + client = baseClient + } + + s := &Source{ + Config: r, + Client: client, + BaseURL: Endpoint, + userAgent: ua, + } + return s, nil +} + +var _ sources.Source = &Source{} + +type Source struct { + Config + Client *http.Client + BaseURL string + userAgent string +} + +func (s *Source) SourceKind() string { + return SourceKind +} + +func (s *Source) ToConfig() sources.SourceConfig { + return s.Config +} + +func (s *Source) GetProjectID() string { + return s.ProjectID +} + +func (s *Source) GetBaseURL() string { + return s.BaseURL +} + +func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) { + if s.UseClientOAuth { + if accessToken == "" { + return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided") + } + token := &oauth2.Token{AccessToken: accessToken} + baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)) + baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport) + return baseClient, nil + } + return s.Client, nil +} + +func (s *Source) UseClientAuthorization() bool { + return s.UseClientOAuth +} diff --git a/internal/sources/cloudgda/cloud_gda_test.go b/internal/sources/cloudgda/cloud_gda_test.go new file mode 100644 index 0000000000..30b977729d --- /dev/null +++ b/internal/sources/cloudgda/cloud_gda_test.go @@ -0,0 +1,213 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudgda_test + +import ( + "context" + "os" + "path/filepath" + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" + "github.com/googleapis/genai-toolbox/internal/testutils" + "go.opentelemetry.io/otel/trace/noop" +) + +func TestParseFromYamlCloudGDA(t *testing.T) { + t.Parallel() + tcs := []struct { + desc string + in string + want server.SourceConfigs + }{ + { + desc: "basic example", + in: ` + sources: + my-gda-instance: + kind: cloud-gemini-data-analytics + projectId: test-project-id + `, + want: map[string]sources.SourceConfig{ + "my-gda-instance": cloudgda.Config{ + Name: "my-gda-instance", + Kind: cloudgda.SourceKind, + ProjectID: "test-project-id", + UseClientOAuth: false, + }, + }, + }, + { + desc: "use client auth example", + in: ` + sources: + my-gda-instance: + kind: cloud-gemini-data-analytics + projectId: another-project + useClientOAuth: true + `, + want: map[string]sources.SourceConfig{ + "my-gda-instance": cloudgda.Config{ + Name: "my-gda-instance", + Kind: cloudgda.SourceKind, + ProjectID: "another-project", + UseClientOAuth: true, + }, + }, + }, + } + for _, tc := range tcs { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if !cmp.Equal(tc.want, got.Sources) { + t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources) + } + }) + } +} + +func TestFailParseFromYaml(t *testing.T) { + t.Parallel() + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "missing projectId", + in: ` + sources: + my-gda-instance: + kind: cloud-gemini-data-analytics + `, + err: "unable to parse source \"my-gda-instance\" as \"cloud-gemini-data-analytics\": Key: 'Config.ProjectID' Error:Field validation for 'ProjectID' failed on the 'required' tag", + }, + } + for _, tc := range tcs { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) + } + }) + } +} + +func TestInitialize(t *testing.T) { + // Create a dummy credentials file for testing ADC + credFile := filepath.Join(t.TempDir(), "application_default_credentials.json") + dummyCreds := `{ + "client_id": "foo", + "client_secret": "bar", + "refresh_token": "baz", + "type": "authorized_user" + }` + if err := os.WriteFile(credFile, []byte(dummyCreds), 0644); err != nil { + t.Fatalf("failed to write dummy credentials file: %v", err) + } + t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", credFile) + + // Use ContextWithUserAgent to avoid "unable to retrieve user agent" error + ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent") + tracer := noop.NewTracerProvider().Tracer("test") + + tcs := []struct { + desc string + cfg cloudgda.Config + wantClientOAuth bool + }{ + { + desc: "initialize with ADC", + cfg: cloudgda.Config{Name: "test-gda", Kind: cloudgda.SourceKind, ProjectID: "test-proj"}, + wantClientOAuth: false, + }, + { + desc: "initialize with client OAuth", + cfg: cloudgda.Config{Name: "test-gda-oauth", Kind: cloudgda.SourceKind, ProjectID: "test-proj", UseClientOAuth: true}, + wantClientOAuth: true, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + src, err := tc.cfg.Initialize(ctx, tracer) + if err != nil { + t.Fatalf("failed to initialize source: %v", err) + } + + gdaSrc, ok := src.(*cloudgda.Source) + if !ok { + t.Fatalf("expected *cloudgda.Source, got %T", src) + } + + // Check that the client is non-nil + if gdaSrc.Client == nil && !tc.wantClientOAuth { + t.Fatal("expected non-nil HTTP client for ADC, got nil") + } + // When client OAuth is true, the source's client should be initialized with a base HTTP client + // that includes the user agent round tripper, but not the OAuth token. The token-aware + // client is created by GetClient. + if gdaSrc.Client == nil && tc.wantClientOAuth { + t.Fatal("expected non-nil HTTP client for client OAuth config, got nil") + } + + // Test UseClientAuthorization method + if gdaSrc.UseClientAuthorization() != tc.wantClientOAuth { + t.Errorf("UseClientAuthorization mismatch: want %t, got %t", tc.wantClientOAuth, gdaSrc.UseClientAuthorization()) + } + + // Test GetClient with accessToken for client OAuth scenarios + if tc.wantClientOAuth { + client, err := gdaSrc.GetClient(ctx, "dummy-token") + if err != nil { + t.Fatalf("GetClient with token failed: %v", err) + } + if client == nil { + t.Fatal("expected non-nil HTTP client from GetClient with token, got nil") + } + // Ensure passing empty token with UseClientOAuth enabled returns error + _, err = gdaSrc.GetClient(ctx, "") + if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" { + t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err) + } + } + }) + } +} diff --git a/internal/sources/cloudmonitoring/cloud_monitoring.go b/internal/sources/cloudmonitoring/cloud_monitoring.go index 4c6db77ed1..d43468687d 100644 --- a/internal/sources/cloudmonitoring/cloud_monitoring.go +++ b/internal/sources/cloudmonitoring/cloud_monitoring.go @@ -29,26 +29,6 @@ import ( const SourceKind string = "cloud-monitoring" -type userAgentRoundTripper struct { - userAgent string - next http.RoundTripper -} - -func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - newReq := *req - newReq.Header = make(http.Header) - for k, v := range req.Header { - newReq.Header[k] = v - } - ua := newReq.Header.Get("User-Agent") - if ua == "" { - newReq.Header.Set("User-Agent", rt.userAgent) - } else { - newReq.Header.Set("User-Agent", ua+" "+rt.userAgent) - } - return rt.next.RoundTrip(&newReq) -} - // validate interface var _ sources.SourceConfig = Config{} @@ -86,10 +66,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So var client *http.Client if r.UseClientOAuth { client = &http.Client{ - Transport: &userAgentRoundTripper{ - userAgent: ua, - next: http.DefaultTransport, - }, + Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport), } } else { // Use Application Default Credentials @@ -98,18 +75,15 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("failed to find default credentials: %w", err) } baseClient := oauth2.NewClient(ctx, creds.TokenSource) - baseClient.Transport = &userAgentRoundTripper{ - userAgent: ua, - next: baseClient.Transport, - } + baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport) client = baseClient } s := &Source{ Config: r, - BaseURL: "https://monitoring.googleapis.com", - Client: client, - UserAgent: ua, + baseURL: "https://monitoring.googleapis.com", + client: client, + userAgent: ua, } return s, nil } @@ -118,9 +92,9 @@ var _ sources.Source = &Source{} type Source struct { Config - BaseURL string `yaml:"baseUrl"` - Client *http.Client - UserAgent string + baseURL string + client *http.Client + userAgent string } func (s *Source) SourceKind() string { @@ -131,6 +105,18 @@ func (s *Source) ToConfig() sources.SourceConfig { return s.Config } +func (s *Source) BaseURL() string { + return s.baseURL +} + +func (s *Source) Client() *http.Client { + return s.client +} + +func (s *Source) UserAgent() string { + return s.userAgent +} + func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) { if s.UseClientOAuth { if accessToken == "" { @@ -139,7 +125,7 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien token := &oauth2.Token{AccessToken: accessToken} return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil } - return s.Client, nil + return s.client, nil } func (s *Source) UseClientAuthorization() bool { diff --git a/internal/sources/cloudsqladmin/cloud_sql_admin.go b/internal/sources/cloudsqladmin/cloud_sql_admin.go index e0827faf9d..3a3ff48caf 100644 --- a/internal/sources/cloudsqladmin/cloud_sql_admin.go +++ b/internal/sources/cloudsqladmin/cloud_sql_admin.go @@ -30,26 +30,6 @@ import ( const SourceKind string = "cloud-sql-admin" -type userAgentRoundTripper struct { - userAgent string - next http.RoundTripper -} - -func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - newReq := *req - newReq.Header = make(http.Header) - for k, v := range req.Header { - newReq.Header[k] = v - } - ua := newReq.Header.Get("User-Agent") - if ua == "" { - newReq.Header.Set("User-Agent", rt.userAgent) - } else { - newReq.Header.Set("User-Agent", ua+" "+rt.userAgent) - } - return rt.next.RoundTrip(&newReq) -} - // validate interface var _ sources.SourceConfig = Config{} @@ -88,10 +68,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So var client *http.Client if r.UseClientOAuth { client = &http.Client{ - Transport: &userAgentRoundTripper{ - userAgent: ua, - next: http.DefaultTransport, - }, + Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport), } } else { // Use Application Default Credentials @@ -100,10 +77,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("failed to find default credentials: %w", err) } baseClient := oauth2.NewClient(ctx, creds.TokenSource) - baseClient.Transport = &userAgentRoundTripper{ - userAgent: ua, - next: baseClient.Transport, - } + baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport) client = baseClient } @@ -136,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig { return s.Config } +func (s *Source) GetDefaultProject() string { + return s.DefaultProject +} + func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) { if s.UseClientOAuth { token := &oauth2.Token{AccessToken: accessToken} diff --git a/internal/sources/http/http.go b/internal/sources/http/http.go index 8f51e84114..b4e9fdd937 100644 --- a/internal/sources/http/http.go +++ b/internal/sources/http/http.go @@ -107,7 +107,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So s := &Source{ Config: r, - Client: &client, + client: &client, } return s, nil @@ -117,7 +117,7 @@ var _ sources.Source = &Source{} type Source struct { Config - Client *http.Client + client *http.Client } func (s *Source) SourceKind() string { @@ -127,3 +127,19 @@ func (s *Source) SourceKind() string { func (s *Source) ToConfig() sources.SourceConfig { return s.Config } + +func (s *Source) HttpDefaultHeaders() map[string]string { + return s.DefaultHeaders +} + +func (s *Source) HttpBaseURL() string { + return s.BaseURL +} + +func (s *Source) HttpQueryParams() map[string]string { + return s.QueryParams +} + +func (s *Source) Client() *http.Client { + return s.client +} diff --git a/internal/sources/looker/looker.go b/internal/sources/looker/looker.go index d88883a7ad..3b60127a55 100644 --- a/internal/sources/looker/looker.go +++ b/internal/sources/looker/looker.go @@ -160,10 +160,6 @@ func (s *Source) ToConfig() sources.SourceConfig { return s.Config } -func (s *Source) GetApiSettings() *rtl.ApiSettings { - return s.ApiSettings -} - func (s *Source) UseClientAuthorization() bool { return strings.ToLower(s.UseClientOAuth) != "false" } @@ -188,6 +184,30 @@ func (s *Source) GoogleCloudTokenSourceWithScope(ctx context.Context, scope stri return google.DefaultTokenSource(ctx, scope) } +func (s *Source) LookerClient() *v4.LookerSDK { + return s.Client +} + +func (s *Source) LookerApiSettings() *rtl.ApiSettings { + return s.ApiSettings +} + +func (s *Source) LookerShowHiddenFields() bool { + return s.ShowHiddenFields +} + +func (s *Source) LookerShowHiddenModels() bool { + return s.ShowHiddenModels +} + +func (s *Source) LookerShowHiddenExplores() bool { + return s.ShowHiddenExplores +} + +func (s *Source) LookerSessionLength() int64 { + return s.SessionLength +} + func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) { cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...) if err != nil { diff --git a/internal/sources/oracle/oracle.go b/internal/sources/oracle/oracle.go index 3b37560004..4de64b402b 100644 --- a/internal/sources/oracle/oracle.go +++ b/internal/sources/oracle/oracle.go @@ -9,9 +9,11 @@ import ( "strings" "github.com/goccy/go-yaml" + _ "github.com/godror/godror" // OCI driver + _ "github.com/sijms/go-ora/v2" // Pure Go driver + "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" - _ "github.com/sijms/go-ora/v2" "go.opentelemetry.io/otel/trace" ) @@ -32,7 +34,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources return nil, err } - // Validate that we have one of: tns_alias, connection_string, or host+service_name + // Validate that we have one of: tnsAlias, connectionString, or host+service_name if err := actual.validate(); err != nil { return nil, fmt.Errorf("invalid Oracle configuration: %w", err) } @@ -43,21 +45,24 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` - ConnectionString string `yaml:"connectionString,omitempty"` // Direct connection string (hostname[:port]/servicename) - TnsAlias string `yaml:"tnsAlias,omitempty"` // TNS alias from tnsnames.ora - Host string `yaml:"host,omitempty"` // Optional when using connectionString/tnsAlias - Port int `yaml:"port,omitempty"` // Explicit port support - ServiceName string `yaml:"serviceName,omitempty"` // Optional when using connectionString/tnsAlias + ConnectionString string `yaml:"connectionString,omitempty"` + TnsAlias string `yaml:"tnsAlias,omitempty"` + TnsAdmin string `yaml:"tnsAdmin,omitempty"` + Host string `yaml:"host,omitempty"` + Port int `yaml:"port,omitempty"` + ServiceName string `yaml:"serviceName,omitempty"` User string `yaml:"user" validate:"required"` Password string `yaml:"password" validate:"required"` - TnsAdmin string `yaml:"tnsAdmin,omitempty"` // Optional: override TNS_ADMIN environment variable + UseOCI bool `yaml:"useOCI,omitempty"` + WalletLocation string `yaml:"walletLocation,omitempty"` } -// validate ensures we have one of: tns_alias, connection_string, or host+service_name func (c Config) validate() error { + hasTnsAdmin := strings.TrimSpace(c.TnsAdmin) != "" hasTnsAlias := strings.TrimSpace(c.TnsAlias) != "" hasConnStr := strings.TrimSpace(c.ConnectionString) != "" hasHostService := strings.TrimSpace(c.Host) != "" && strings.TrimSpace(c.ServiceName) != "" + hasWallet := strings.TrimSpace(c.WalletLocation) != "" connectionMethods := 0 if hasTnsAlias { @@ -78,6 +83,14 @@ func (c Config) validate() error { return fmt.Errorf("provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'") } + if hasTnsAdmin && !c.UseOCI { + return fmt.Errorf("`tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead") + } + + if hasWallet && c.UseOCI { + return fmt.Errorf("when using an OCI driver, use `tnsAdmin` to specify credentials file location instead") + } + return nil } @@ -132,7 +145,8 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi panic(err) } - // Set TNS_ADMIN environment variable if specified in config. + hasWallet := strings.TrimSpace(config.WalletLocation) != "" + if config.TnsAdmin != "" { originalTnsAdmin := os.Getenv("TNS_ADMIN") os.Setenv("TNS_ADMIN", config.TnsAdmin) @@ -147,28 +161,49 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi }() } - var serverString string + var connectStringBase string if config.TnsAlias != "" { - // Use TNS alias - serverString = strings.TrimSpace(config.TnsAlias) + connectStringBase = strings.TrimSpace(config.TnsAlias) } else if config.ConnectionString != "" { - // Use provided connection string directly (hostname[:port]/servicename format) - serverString = strings.TrimSpace(config.ConnectionString) + connectStringBase = strings.TrimSpace(config.ConnectionString) } else { - // Build connection string from host and service_name if config.Port > 0 { - serverString = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName) + connectStringBase = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName) } else { - serverString = fmt.Sprintf("%s/%s", config.Host, config.ServiceName) + connectStringBase = fmt.Sprintf("%s/%s", config.Host, config.ServiceName) } } - connStr := fmt.Sprintf("oracle://%s:%s@%s", - config.User, config.Password, serverString) + var driverName string + var finalConnStr string - db, err := sql.Open("oracle", connStr) + if config.UseOCI { + // Use godror driver (requires OCI) + driverName = "godror" + finalConnStr = fmt.Sprintf(`user="%s" password="%s" connectString="%s"`, + config.User, config.Password, connectStringBase) + logger.DebugContext(ctx, fmt.Sprintf("Using godror driver (OCI-based) with connectString: %s\n", connectStringBase)) + } else { + // Use go-ora driver (pure Go) + driverName = "oracle" + + user := config.User + password := config.Password + + if hasWallet { + finalConnStr = fmt.Sprintf("oracle://%s:%s@%s?ssl=true&wallet=%s", + user, password, connectStringBase, config.WalletLocation) + } else { + // Standard go-ora connection + finalConnStr = fmt.Sprintf("oracle://%s:%s@%s", + config.User, config.Password, connectStringBase) + logger.DebugContext(ctx, fmt.Sprintf("Using go-ora driver (pure-Go) with serverString: %s\n", connectStringBase)) + } + } + + db, err := sql.Open(driverName, finalConnStr) if err != nil { - return nil, fmt.Errorf("unable to open Oracle connection: %w", err) + return nil, fmt.Errorf("unable to open Oracle connection with driver %s: %w", driverName, err) } return db, nil diff --git a/internal/sources/oracle/oracle_test.go b/internal/sources/oracle/oracle_test.go new file mode 100644 index 0000000000..3d8f4c7ba5 --- /dev/null +++ b/internal/sources/oracle/oracle_test.go @@ -0,0 +1,200 @@ +// Copyright © 2025, Oracle and/or its affiliates. + +package oracle_test + +import ( + "strings" + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/sources/oracle" + "github.com/googleapis/genai-toolbox/internal/testutils" +) + +func TestParseFromYamlOracle(t *testing.T) { + tcs := []struct { + desc string + in string + want server.SourceConfigs + }{ + { + desc: "connection string and useOCI=true", + in: ` + sources: + my-oracle-cs: + kind: oracle + connectionString: "my-host:1521/XEPDB1" + user: my_user + password: my_pass + useOCI: true + `, + want: server.SourceConfigs{ + "my-oracle-cs": oracle.Config{ + Name: "my-oracle-cs", + Kind: oracle.SourceKind, + ConnectionString: "my-host:1521/XEPDB1", + User: "my_user", + Password: "my_pass", + UseOCI: true, + }, + }, + }, + { + desc: "host/port/serviceName and default useOCI=false", + in: ` + sources: + my-oracle-host: + kind: oracle + host: my-host + port: 1521 + serviceName: ORCLPDB + user: my_user + password: my_pass + `, + want: server.SourceConfigs{ + "my-oracle-host": oracle.Config{ + Name: "my-oracle-host", + Kind: oracle.SourceKind, + Host: "my-host", + Port: 1521, + ServiceName: "ORCLPDB", + User: "my_user", + Password: "my_pass", + UseOCI: false, + }, + }, + }, + { + desc: "tnsAlias and TnsAdmin specified with explicit useOCI=true", + in: ` + sources: + my-oracle-tns-oci: + kind: oracle + tnsAlias: FINANCE_DB + tnsAdmin: /opt/oracle/network/admin + user: my_user + password: my_pass + useOCI: true + `, + want: server.SourceConfigs{ + "my-oracle-tns-oci": oracle.Config{ + Name: "my-oracle-tns-oci", + Kind: oracle.SourceKind, + TnsAlias: "FINANCE_DB", + TnsAdmin: "/opt/oracle/network/admin", + User: "my_user", + Password: "my_pass", + UseOCI: true, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if !cmp.Equal(tc.want, got.Sources) { + t.Fatalf("incorrect parse:\nwant: %v\ngot: %v\ndiff: %s", tc.want, got.Sources, cmp.Diff(tc.want, got.Sources)) + } + }) + } +} + +func TestFailParseFromYamlOracle(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "extra field", + in: ` + sources: + my-oracle-instance: + kind: oracle + host: my-host + serviceName: ORCL + user: my_user + password: my_pass + extraField: value + `, + err: "unable to parse source \"my-oracle-instance\" as \"oracle\": [1:1] unknown field \"extraField\"\n> 1 | extraField: value\n ^\n 2 | host: my-host\n 3 | kind: oracle\n 4 | password: my_pass\n 5 | ", + }, + { + desc: "missing required password field", + in: ` + sources: + my-oracle-instance: + kind: oracle + host: my-host + serviceName: ORCL + user: my_user + `, + err: "unable to parse source \"my-oracle-instance\" as \"oracle\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag", + }, + { + desc: "missing connection method fields (validate fails)", + in: ` + sources: + my-oracle-instance: + kind: oracle + user: my_user + password: my_pass + `, + err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'", + }, + { + desc: "multiple connection methods provided (validate fails)", + in: ` + sources: + my-oracle-instance: + kind: oracle + host: my-host + serviceName: ORCL + connectionString: "my-host:1521/XEPDB1" + user: my_user + password: my_pass + `, + err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'", + }, + { + desc: "fail on tnsAdmin with useOCI=false", + in: ` + sources: + my-oracle-fail: + kind: oracle + tnsAlias: FINANCE_DB + tnsAdmin: /opt/oracle/network/admin + user: my_user + password: my_pass + useOCI: false + `, + err: "unable to parse source \"my-oracle-fail\" as \"oracle\": invalid Oracle configuration: `tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := strings.ReplaceAll(err.Error(), "\r", "") + + if errStr != tc.err { + t.Fatalf("unexpected error:\ngot:\n%q\nwant:\n%q\n", errStr, tc.err) + } + }) + } +} diff --git a/internal/sources/serverlessspark/serverlessspark.go b/internal/sources/serverlessspark/serverlessspark.go index 2e95199ecd..c63adb6863 100644 --- a/internal/sources/serverlessspark/serverlessspark.go +++ b/internal/sources/serverlessspark/serverlessspark.go @@ -96,6 +96,14 @@ func (s *Source) ToConfig() sources.SourceConfig { return s.Config } +func (s *Source) GetProject() string { + return s.Project +} + +func (s *Source) GetLocation() string { + return s.Location +} + func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient { return s.Client } diff --git a/internal/testutils/testutils.go b/internal/testutils/testutils.go index f6c30b859c..82975fb321 100644 --- a/internal/testutils/testutils.go +++ b/internal/testutils/testutils.go @@ -46,6 +46,11 @@ func ContextWithNewLogger() (context.Context, error) { return util.WithLogger(ctx, logger), nil } +// ContextWithUserAgent creates a new context with a specified user agent string. +func ContextWithUserAgent(ctx context.Context, userAgent string) context.Context { + return util.WithUserAgent(ctx, userAgent) +} + // WaitForString waits until the server logs a single line that matches the provided regex. // returns the output of whatever the server sent so far. func WaitForString(ctx context.Context, re *regexp.Regexp, pr io.ReadCloser) (string, error) { diff --git a/internal/tools/alloydb/alloydbcreatecluster/alloydbcreatecluster.go b/internal/tools/alloydb/alloydbcreatecluster/alloydbcreatecluster.go index eeec42b655..0993efd1da 100644 --- a/internal/tools/alloydb/alloydbcreatecluster/alloydbcreatecluster.go +++ b/internal/tools/alloydb/alloydbcreatecluster/alloydbcreatecluster.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/alloydb/v1" @@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetService(context.Context, string) (*alloydb.Service, error) +} + // Configuration for the create-cluster tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -97,7 +102,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -107,7 +111,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-cluster tool. type Tool struct { Config - Source *alloydbadmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest @@ -120,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok || project == "" { @@ -151,7 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid 'user' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -198,10 +206,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydbcreateinstance/alloydbcreateinstance.go b/internal/tools/alloydb/alloydbcreateinstance/alloydbcreateinstance.go index 2c6344a2b1..6d3382c516 100644 --- a/internal/tools/alloydb/alloydbcreateinstance/alloydbcreateinstance.go +++ b/internal/tools/alloydb/alloydbcreateinstance/alloydbcreateinstance.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/alloydb/v1" @@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetService(context.Context, string) (*alloydb.Service, error) +} + // Configuration for the create-instance tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -108,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-instance tool. type Tool struct { Config - Source *alloydbadmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest @@ -121,6 +124,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok || project == "" { @@ -147,7 +155,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -208,10 +216,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydbcreateuser/alloydbcreateuser.go b/internal/tools/alloydb/alloydbcreateuser/alloydbcreateuser.go index 873995e547..921dc500fc 100644 --- a/internal/tools/alloydb/alloydbcreateuser/alloydbcreateuser.go +++ b/internal/tools/alloydb/alloydbcreateuser/alloydbcreateuser.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/alloydb/v1" @@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetService(context.Context, string) (*alloydb.Service, error) +} + // Configuration for the create-user tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -108,9 +112,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-user tool. type Tool struct { Config - Source *alloydbadmin.Source - AllParams parameters.Parameters `yaml:"allParams"` - + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -121,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok || project == "" { @@ -147,7 +154,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -208,10 +215,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydbgetcluster/alloydbgetcluster.go b/internal/tools/alloydb/alloydbgetcluster/alloydbgetcluster.go index 30cf291bea..77683d1481 100644 --- a/internal/tools/alloydb/alloydbgetcluster/alloydbgetcluster.go +++ b/internal/tools/alloydb/alloydbgetcluster/alloydbgetcluster.go @@ -20,9 +20,9 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/alloydb/v1" ) const kind string = "alloydb-get-cluster" @@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetService(context.Context, string) (*alloydb.Service, error) +} + // Configuration for the get-cluster tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -104,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the get-cluster tool. type Tool struct { Config - Source *alloydbadmin.Source AllParams parameters.Parameters manifest tools.Manifest @@ -117,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -132,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -167,10 +176,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydbgetinstance/alloydbgetinstance.go b/internal/tools/alloydb/alloydbgetinstance/alloydbgetinstance.go index 44dbb7d42d..ed67ed54c4 100644 --- a/internal/tools/alloydb/alloydbgetinstance/alloydbgetinstance.go +++ b/internal/tools/alloydb/alloydbgetinstance/alloydbgetinstance.go @@ -20,9 +20,9 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/alloydb/v1" ) const kind string = "alloydb-get-instance" @@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetService(context.Context, string) (*alloydb.Service, error) +} + // Configuration for the get-instance tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the get-instance tool. type Tool struct { Config - Source *alloydbadmin.Source - AllParams parameters.Parameters - + AllParams parameters.Parameters manifest tools.Manifest mcpManifest tools.McpManifest } @@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid 'instance' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydbgetuser/alloydbgetuser.go b/internal/tools/alloydb/alloydbgetuser/alloydbgetuser.go index 7c33bd340c..d21a984e02 100644 --- a/internal/tools/alloydb/alloydbgetuser/alloydbgetuser.go +++ b/internal/tools/alloydb/alloydbgetuser/alloydbgetuser.go @@ -20,9 +20,9 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/alloydb/v1" ) const kind string = "alloydb-get-user" @@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetService(context.Context, string) (*alloydb.Service, error) +} + // Configuration for the get-user tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the get-user tool. type Tool struct { Config - Source *alloydbadmin.Source - AllParams parameters.Parameters - + AllParams parameters.Parameters manifest tools.Manifest mcpManifest tools.McpManifest } @@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid 'user' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydblistclusters/alloydblistclusters.go b/internal/tools/alloydb/alloydblistclusters/alloydblistclusters.go index eab2c4a7e8..1b29b9a37a 100644 --- a/internal/tools/alloydb/alloydblistclusters/alloydblistclusters.go +++ b/internal/tools/alloydb/alloydblistclusters/alloydblistclusters.go @@ -20,9 +20,9 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/alloydb/v1" ) const kind string = "alloydb-list-clusters" @@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetService(context.Context, string) (*alloydb.Service, error) +} + // Configuration for the list-clusters tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -93,7 +99,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -103,9 +108,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the list-clusters tool. type Tool struct { Config - Source *alloydbadmin.Source - AllParams parameters.Parameters `yaml:"allParams"` - + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -116,6 +119,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -127,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid 'location' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -162,10 +170,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydblistinstances/alloydblistinstances.go b/internal/tools/alloydb/alloydblistinstances/alloydblistinstances.go index 02e8d026a5..7448241738 100644 --- a/internal/tools/alloydb/alloydblistinstances/alloydblistinstances.go +++ b/internal/tools/alloydb/alloydblistinstances/alloydblistinstances.go @@ -20,9 +20,9 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/alloydb/v1" ) const kind string = "alloydb-list-instances" @@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetService(context.Context, string) (*alloydb.Service, error) +} + // Configuration for the list-instances tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the list-instances tool. type Tool struct { Config - Source *alloydbadmin.Source - AllParams parameters.Parameters `yaml:"allParams"` - + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydblistusers/alloydblistusers.go b/internal/tools/alloydb/alloydblistusers/alloydblistusers.go index c33a982382..c14d9bea5c 100644 --- a/internal/tools/alloydb/alloydblistusers/alloydblistusers.go +++ b/internal/tools/alloydb/alloydblistusers/alloydblistusers.go @@ -20,9 +20,9 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/alloydb/v1" ) const kind string = "alloydb-list-users" @@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetService(context.Context, string) (*alloydb.Service, error) +} + // Configuration for the list-users tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the list-users tool. type Tool struct { Config - Source *alloydbadmin.Source - AllParams parameters.Parameters `yaml:"allParams"` - + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go b/internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go index dd909c42c3..f2de0b37d0 100644 --- a/internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go +++ b/internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go @@ -25,9 +25,9 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/alloydb/v1" ) const kind string = "alloydb-wait-for-operation" @@ -89,6 +89,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetService(context.Context, string) (*alloydb.Service, error) +} + // Config defines the configuration for the wait-for-operation tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -119,12 +125,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -180,7 +186,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -194,19 +199,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the wait-for-operation tool. type Tool struct { Config - - Source *alloydbadmin.Source - AllParams parameters.Parameters `yaml:"allParams"` + AllParams parameters.Parameters `yaml:"allParams"` + Client *http.Client + manifest tools.Manifest + mcpManifest tools.McpManifest // Polling configuration Delay time.Duration MaxDelay time.Duration Multiplier float64 MaxRetries int - - Client *http.Client - manifest tools.Manifest - mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -215,6 +217,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -230,7 +237,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("missing 'operation' parameter") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -363,10 +370,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydbainl/alloydbainl.go b/internal/tools/alloydbainl/alloydbainl.go index 39564680ad..3c94860e53 100644 --- a/internal/tools/alloydbainl/alloydbainl.go +++ b/internal/tools/alloydbainl/alloydbainl.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -47,11 +46,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - numParams := len(cfg.NLConfigParameters) quotedNameParts := make([]string, 0, numParams) placeholderParts := make([]string, 0, numParams) @@ -126,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) Config: cfg, Parameters: cfg.NLConfigParameters, Statement: stmt, - Pool: s.PostgresPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -139,9 +120,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *pgxpool.Pool + Parameters parameters.Parameters `yaml:"parameters"` Statement string manifest tools.Manifest mcpManifest tools.McpManifest @@ -152,6 +131,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + pool := source.PostgresPool() + sliceParams := params.AsSlice() allParamValues := make([]any, len(sliceParams)+1) allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question @@ -160,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para allParamValues[i+2] = fmt.Sprintf("%s", param) } - results, err := t.Pool.Query(ctx, t.Statement, allParamValues...) + results, err := pool.Query(ctx, t.Statement, allParamValues...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues) } @@ -203,10 +188,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go index 6fe64b28e2..61b90a1d11 100644 --- a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go +++ b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go @@ -57,11 +57,6 @@ type compatibleSource interface { BigQuerySession() bigqueryds.BigQuerySessionProvider } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } allowedDatasets := s.BigQueryAllowedDatasets() @@ -136,17 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - RestService: s.BigQueryRestService(), - IsDatasetAllowed: s.IsDatasetAllowed, - AllowedDatasets: allowedDatasets, - SessionProvider: s.BigQuerySession(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -156,17 +144,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - RestService *bigqueryrestapi.Service - ClientCreator bigqueryds.BigqueryClientCreator - IsDatasetAllowed func(projectID, datasetID string) bool - AllowedDatasets []string - SessionProvider bigqueryds.BigQuerySessionProvider - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -175,23 +155,27 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke runs the contribution analysis. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() inputData, ok := paramsMap["input_data"].(string) if !ok { return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"]) } - bqClient := t.Client - restService := t.RestService - var err error + bqClient := source.BigQueryClient() + restService := source.BigQueryRestService() // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, restService, err = t.ClientCreator(tokenStr, true) + bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -229,9 +213,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var inputDataSource string trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData)) if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") { - if len(t.AllowedDatasets) > 0 { + if len(source.BigQueryAllowedDatasets()) > 0 { var connProps []*bigqueryapi.ConnectionProperty - session, err := t.SessionProvider(ctx) + session, err := source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session: %w", err) } @@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para {Key: "session_id", Value: session.ID}, } } - dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps) + dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps) if err != nil { return nil, fmt.Errorf("query validation failed: %w", err) } @@ -252,7 +236,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para queryStats := dryRunJob.Statistics.Query if queryStats != nil { for _, tableRef := range queryStats.ReferencedTables { - if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { + if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId) } } @@ -262,18 +246,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } inputDataSource = fmt.Sprintf("(%s)", inputData) } else { - if len(t.AllowedDatasets) > 0 { + if len(source.BigQueryAllowedDatasets()) > 0 { parts := strings.Split(inputData, ".") var projectID, datasetID string switch len(parts) { case 3: // project.dataset.table projectID, datasetID = parts[0], parts[1] case 2: // dataset.table - projectID, datasetID = t.Client.Project(), parts[0] + projectID, datasetID = source.BigQueryClient().Project(), parts[0] default: return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData) } - if !t.IsDatasetAllowed(projectID, datasetID) { + if !source.IsDatasetAllowed(projectID, datasetID) { return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData) } } @@ -292,7 +276,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Get session from provider if in protected mode. // Otherwise, a new session will be created by the first query. - session, err := t.SessionProvider(ctx) + session, err := source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session: %w", err) } @@ -385,10 +369,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go index 945ac9fe5a..6d54f000b1 100644 --- a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go +++ b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go @@ -26,7 +26,6 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -105,11 +104,6 @@ type CAPayload struct { ClientIdEnum string `json:"clientIdEnum"` } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -135,7 +129,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } allowedDatasets := s.BigQueryAllowedDatasets() @@ -153,31 +147,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) params := parameters.Parameters{userQueryParameter, tableRefsParameter} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) - // Get cloud-platform token source for Gemini Data Analytics API during initialization - var bigQueryTokenSourceWithScope oauth2.TokenSource - if !s.UseClientAuthorization() { - ctx := context.Background() - ts, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform") - if err != nil { - return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err) - } - bigQueryTokenSourceWithScope = ts - } - // finish tool setup t := Tool{ - Config: cfg, - Project: s.BigQueryProject(), - Location: s.BigQueryLocation(), - Parameters: params, - Client: s.BigQueryClient(), - UseClientOAuth: s.UseClientAuthorization(), - TokenSource: bigQueryTokenSourceWithScope, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, - MaxQueryResultRows: s.GetMaxQueryResultRows(), - IsDatasetAllowed: s.IsDatasetAllowed, - AllowedDatasets: allowedDatasets, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -187,18 +162,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project string - Location string - Client *bigqueryapi.Client - TokenSource oauth2.TokenSource - manifest tools.Manifest - mcpManifest tools.McpManifest - MaxQueryResultRows int - IsDatasetAllowed func(projectID, datasetID string) bool - AllowedDatasets []string + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -206,11 +172,15 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + var tokenStr string - var err error // Get credentials for the API call - if t.UseClientOAuth { + if source.UseClientAuthorization() { // Use client-side access token if accessToken == "" { return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized) @@ -220,11 +190,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error parsing access token: %w", err) } } else { + // Get cloud-platform token source for Gemini Data Analytics API during initialization + tokenSource, err := source.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform") + if err != nil { + return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err) + } + // Use cloud-platform token source for Gemini Data Analytics API - if t.TokenSource == nil { + if tokenSource == nil { return nil, fmt.Errorf("cloud-platform token source is missing") } - token, err := t.TokenSource.Token() + token, err := tokenSource.Token() if err != nil { return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err) } @@ -245,17 +221,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - if len(t.AllowedDatasets) > 0 { + if len(source.BigQueryAllowedDatasets()) > 0 { for _, tableRef := range tableRefs { - if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) { + if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) { return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID) } } } // Construct URL, headers, and payload - projectID := t.Project - location := t.Location + projectID := source.BigQueryProject() + location := source.BigQueryLocation() if location == "" { location = "us" } @@ -279,7 +255,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Call the streaming API - response, err := getStream(caURL, payload, headers, t.MaxQueryResultRows) + response, err := getStream(caURL, payload, headers, source.GetMaxQueryResultRows()) if err != nil { return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err) } @@ -303,8 +279,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } // StreamMessage represents a single message object from the streaming API response. @@ -580,6 +560,6 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s return append(messages, newMessage) } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go index b3e4d04f16..a70d4d342d 100644 --- a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go +++ b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go @@ -60,11 +60,6 @@ type compatibleSource interface { BigQueryAllowedDatasets() []string } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -90,7 +85,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } var sqlDescriptionBuilder strings.Builder @@ -136,18 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - RestService: s.BigQueryRestService(), - WriteMode: s.BigQueryWriteMode(), - SessionProvider: s.BigQuerySession(), - IsDatasetAllowed: s.IsDatasetAllowed, - AllowedDatasets: allowedDatasets, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -157,18 +144,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - RestService *bigqueryrestapi.Service - WriteMode string - SessionProvider bigqueryds.BigQuerySessionProvider - ClientCreator bigqueryds.BigqueryClientCreator - IsDatasetAllowed func(projectID, datasetID string) bool - AllowedDatasets []string - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -176,6 +154,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -186,17 +169,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"]) } - bqClient := t.Client - restService := t.RestService + bqClient := source.BigQueryClient() + restService := source.BigQueryRestService() - var err error // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, restService, err = t.ClientCreator(tokenStr, true) + bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -204,8 +186,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var connProps []*bigqueryapi.ConnectionProperty var session *bigqueryds.Session - if t.WriteMode == bigqueryds.WriteModeProtected { - session, err = t.SessionProvider(ctx) + if source.BigQueryWriteMode() == bigqueryds.WriteModeProtected { + session, err = source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err) } @@ -221,7 +203,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para statementType := dryRunJob.Statistics.Query.StatementType - switch t.WriteMode { + switch source.BigQueryWriteMode() { case bigqueryds.WriteModeBlocked: if statementType != "SELECT" { return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed") @@ -235,7 +217,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - if len(t.AllowedDatasets) > 0 { + if len(source.BigQueryAllowedDatasets()) > 0 { switch statementType { case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA": return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType) @@ -270,7 +252,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } else if statementType != "SELECT" { // If dry run yields no tables, fall back to the parser for non-SELECT statements // to catch unsafe operations like EXECUTE IMMEDIATE. - parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project()) + parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project()) if parseErr != nil { // If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail. return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr) @@ -282,7 +264,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para parts := strings.Split(tableID, ".") if len(parts) == 3 { projectID, datasetID := parts[0], parts[1] - if !t.IsDatasetAllowed(projectID, datasetID) { + if !source.IsDatasetAllowed(projectID, datasetID) { return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID) } } @@ -374,10 +356,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go index 583bc51df1..034bce3501 100644 --- a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go +++ b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go @@ -57,11 +57,6 @@ type compatibleSource interface { BigQuerySession() bigqueryds.BigQuerySessionProvider } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } allowedDatasets := s.BigQueryAllowedDatasets() @@ -116,17 +111,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - RestService: s.BigQueryRestService(), - IsDatasetAllowed: s.IsDatasetAllowed, - SessionProvider: s.BigQuerySession(), - AllowedDatasets: allowedDatasets, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -136,17 +124,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - RestService *bigqueryrestapi.Service - ClientCreator bigqueryds.BigqueryClientCreator - IsDatasetAllowed func(projectID, datasetID string) bool - AllowedDatasets []string - SessionProvider bigqueryds.BigQuerySessionProvider - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -154,6 +134,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() historyData, ok := paramsMap["history_data"].(string) if !ok { @@ -188,17 +173,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - bqClient := t.Client - restService := t.RestService - var err error + bqClient := source.BigQueryClient() + restService := source.BigQueryRestService() // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, restService, err = t.ClientCreator(tokenStr, false) + bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, false) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -207,9 +191,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var historyDataSource string trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData)) if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") { - if len(t.AllowedDatasets) > 0 { + if len(source.BigQueryAllowedDatasets()) > 0 { var connProps []*bigqueryapi.ConnectionProperty - session, err := t.SessionProvider(ctx) + session, err := source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session: %w", err) } @@ -218,7 +202,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para {Key: "session_id", Value: session.ID}, } } - dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps) + dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps) if err != nil { return nil, fmt.Errorf("query validation failed: %w", err) } @@ -230,7 +214,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para queryStats := dryRunJob.Statistics.Query if queryStats != nil { for _, tableRef := range queryStats.ReferencedTables { - if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { + if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId) } } @@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } historyDataSource = fmt.Sprintf("(%s)", historyData) } else { - if len(t.AllowedDatasets) > 0 { + if len(source.BigQueryAllowedDatasets()) > 0 { parts := strings.Split(historyData, ".") var projectID, datasetID string @@ -249,13 +233,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para projectID = parts[0] datasetID = parts[1] case 2: // dataset.table - projectID = t.Client.Project() + projectID = source.BigQueryClient().Project() datasetID = parts[0] default: return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData) } - if !t.IsDatasetAllowed(projectID, datasetID) { + if !source.IsDatasetAllowed(projectID, datasetID) { return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData) } } @@ -279,7 +263,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // JobStatistics.QueryStatistics.StatementType query := bqClient.Query(sql) query.Location = bqClient.Location - session, err := t.SessionProvider(ctx) + session, err := source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session: %w", err) } @@ -349,10 +333,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go b/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go index d570eaf327..b083c49e2c 100644 --- a/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go +++ b/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go @@ -54,11 +54,6 @@ type compatibleSource interface { BigQueryAllowedDatasets() []string } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -84,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } defaultProjectID := s.BigQueryProject() @@ -104,14 +99,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - IsDatasetAllowed: s.IsDatasetAllowed, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -121,15 +112,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - ClientCreator bigqueryds.BigqueryClientCreator - Statement string - IsDatasetAllowed func(projectID, datasetID string) bool - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -137,6 +122,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() projectId, ok := mapParams[projectKey].(string) if !ok { @@ -148,22 +138,21 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey) } - bqClient := t.Client - var err error + bqClient := source.BigQueryClient() // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, _, err = t.ClientCreator(tokenStr, false) + bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } } - if !t.IsDatasetAllowed(projectId, datasetId) { + if !source.IsDatasetAllowed(projectId, datasetId) { return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) } @@ -193,10 +182,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go b/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go index c6174e4199..b896244ed0 100644 --- a/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go +++ b/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go @@ -55,11 +55,6 @@ type compatibleSource interface { BigQueryAllowedDatasets() []string } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } defaultProjectID := s.BigQueryProject() @@ -108,14 +103,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - IsDatasetAllowed: s.IsDatasetAllowed, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -125,15 +116,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - ClientCreator bigqueryds.BigqueryClientCreator - Statement string - IsDatasetAllowed func(projectID, datasetID string) bool - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -141,6 +126,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() projectId, ok := mapParams[projectKey].(string) if !ok { @@ -157,20 +147,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey) } - if !t.IsDatasetAllowed(projectId, datasetId) { + if !source.IsDatasetAllowed(projectId, datasetId) { return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) } - bqClient := t.Client + bqClient := source.BigQueryClient() - var err error // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, _, err = t.ClientCreator(tokenStr, false) + bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -203,10 +192,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go b/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go index 99484a3c20..dafe9b2246 100644 --- a/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go +++ b/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go @@ -52,11 +52,6 @@ type compatibleSource interface { BigQueryAllowedDatasets() []string } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -82,7 +77,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } var projectParameter parameters.Parameter @@ -103,14 +98,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - AllowedDatasets: allowedDatasets, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -120,15 +111,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - ClientCreator bigqueryds.BigqueryClientCreator - Statement string - AllowedDatasets []string - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -136,8 +121,13 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - if len(t.AllowedDatasets) > 0 { - return t.AllowedDatasets, nil + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + if len(source.BigQueryAllowedDatasets()) > 0 { + return source.BigQueryAllowedDatasets(), nil } mapParams := params.AsMap() projectId, ok := mapParams[projectKey].(string) @@ -145,14 +135,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey) } - bqClient := t.Client + bqClient := source.BigQueryClient() // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, _, err = t.ClientCreator(tokenStr, false) + bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -197,10 +187,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go b/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go index d02a550304..11987c6dac 100644 --- a/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go +++ b/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go @@ -55,11 +55,6 @@ type compatibleSource interface { BigQueryAllowedDatasets() []string } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } defaultProjectID := s.BigQueryProject() @@ -107,14 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - IsDatasetAllowed: s.IsDatasetAllowed, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -124,15 +115,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - ClientCreator bigqueryds.BigqueryClientCreator - IsDatasetAllowed func(projectID, datasetID string) bool - Statement string - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -140,6 +125,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() projectId, ok := mapParams[projectKey].(string) if !ok { @@ -151,18 +141,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey) } - if !t.IsDatasetAllowed(projectId, datasetId) { + if !source.IsDatasetAllowed(projectId, datasetId) { return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) } - bqClient := t.Client + bqClient := source.BigQueryClient() // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, _, err = t.ClientCreator(tokenStr, false) + bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -208,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerysearchcatalog/bigquerysearchcatalog.go b/internal/tools/bigquery/bigquerysearchcatalog/bigquerysearchcatalog.go index 0c53a7be6d..e134e9f298 100644 --- a/internal/tools/bigquery/bigquerysearchcatalog/bigquerysearchcatalog.go +++ b/internal/tools/bigquery/bigquerysearchcatalog/bigquerysearchcatalog.go @@ -51,11 +51,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -72,20 +67,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // Initialize the search configuration with the provided sources - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - - // Get the Dataplex client using the method from the source - makeCatalogClient := s.MakeDataplexCatalogClient() - prompt := parameters.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.") datasetIds := parameters.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", parameters.NewStringParameter("datasetId", "The IDs of the bigquery dataset.")) projectIds := parameters.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", parameters.NewStringParameter("projectId", "The IDs of the bigquery project.")) @@ -100,11 +81,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, params, nil) t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - MakeCatalogClient: makeCatalogClient, - ProjectID: s.BigQueryProject(), + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -117,12 +95,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) type Tool struct { Config - Parameters parameters.Parameters - UseClientOAuth bool - MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error) - ProjectID string - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -133,8 +108,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } func constructSearchQueryHelper(predicate string, operator string, items []string) string { @@ -207,6 +186,11 @@ func ExtractType(resourceString string) string { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() pageSize := int32(paramsMap["pageSize"].(int)) prompt, _ := paramsMap["prompt"].(string) @@ -228,14 +212,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para req := &dataplexpb.SearchEntriesRequest{ Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)), - Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID), + Name: fmt.Sprintf("projects/%s/locations/global", source.BigQueryProject()), PageSize: pageSize, SemanticSearch: true, } - catalogClient, dataplexClientCreator, _ := t.MakeCatalogClient() + catalogClient, dataplexClientCreator, _ := source.MakeDataplexCatalogClient()() - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) @@ -248,7 +232,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para it := catalogClient.SearchEntries(ctx, req) if it == nil { - return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID) + return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.BigQueryProject()) } var results []Response @@ -288,6 +272,6 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerysql/bigquerysql.go b/internal/tools/bigquery/bigquerysql/bigquerysql.go index 5e941deb7a..fa02f658eb 100644 --- a/internal/tools/bigquery/bigquerysql/bigquerysql.go +++ b/internal/tools/bigquery/bigquerysql/bigquerysql.go @@ -57,11 +57,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -81,18 +76,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -102,15 +85,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - AllParams: allParameters, - UseClientOAuth: s.UseClientAuthorization(), - Client: s.BigQueryClient(), - RestService: s.BigQueryRestService(), - SessionProvider: s.BigQuerySession(), - ClientCreator: s.BigQueryClientCreator(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -120,15 +98,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - AllParams parameters.Parameters `yaml:"allParams"` - - Client *bigqueryapi.Client - RestService *bigqueryrestapi.Service - SessionProvider bigqueryds.BigQuerySessionProvider - ClientCreator bigqueryds.BigqueryClientCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + AllParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -136,6 +108,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters)) lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters)) @@ -212,16 +189,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para lowLevelParams = append(lowLevelParams, lowLevelParam) } - bqClient := t.Client - restService := t.RestService + bqClient := source.BigQueryClient() + restService := source.BigQueryRestService() // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, restService, err = t.ClientCreator(tokenStr, true) + bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -232,8 +209,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para query.Location = bqClient.Location connProps := []*bigqueryapi.ConnectionProperty{} - if t.SessionProvider != nil { - session, err := t.SessionProvider(ctx) + if source.BigQuerySession() != nil { + session, err := source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session: %w", err) } @@ -311,10 +288,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigtable/bigtable.go b/internal/tools/bigtable/bigtable.go index 3f63994815..fe93630f95 100644 --- a/internal/tools/bigtable/bigtable.go +++ b/internal/tools/bigtable/bigtable.go @@ -21,7 +21,6 @@ import ( "cloud.google.com/go/bigtable" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - bigtabledb "github.com/googleapis/genai-toolbox/internal/sources/bigtable" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -46,11 +45,6 @@ type compatibleSource interface { BigtableClient() *bigtable.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &bigtabledb.Source{} - -var compatibleSources = [...]string{bigtabledb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -93,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Client: s.BigtableClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -105,9 +86,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Client *bigtable.Client + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -156,6 +135,11 @@ func getMapParamsType(tparams parameters.Parameters, params parameters.ParamValu } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -172,7 +156,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("fail to get map params: %w", err) } - ps, err := t.Client.PrepareStatement( + ps, err := source.BigtableClient().PrepareStatement( ctx, newStatement, mapParamsType, @@ -224,10 +208,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cassandra/cassandracql/cassandracql.go b/internal/tools/cassandra/cassandracql/cassandracql.go index b650e3ba97..a05d0815ba 100644 --- a/internal/tools/cassandra/cassandracql/cassandracql.go +++ b/internal/tools/cassandra/cassandracql/cassandracql.go @@ -21,7 +21,6 @@ import ( gocql "github.com/apache/cassandra-gocql-driver/v2" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cassandra" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -46,10 +45,6 @@ type compatibleSource interface { CassandraSession() *gocql.Session } -var _ compatibleSource = &cassandra.Source{} - -var compatibleSources = [...]string{cassandra.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,20 +56,15 @@ type Config struct { TemplateParameters parameters.Parameters `yaml:"templateParameters"` } +var _ tools.ToolConfig = Config{} + +// ToolConfigKind implements tools.ToolConfig. +func (c Config) ToolConfigKind() string { + return kind +} + // Initialize implements tools.ToolConfig. func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[c.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", c.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(c.TemplateParameters, c.Parameters) if err != nil { return nil, err @@ -85,25 +75,17 @@ func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { t := Tool{ Config: c, AllParams: allParameters, - Session: s.CassandraSession(), manifest: tools.Manifest{Description: c.Description, Parameters: paramManifest, AuthRequired: c.AuthRequired}, mcpManifest: mcpManifest, } return t, nil } -// ToolConfigKind implements tools.ToolConfig. -func (c Config) ToolConfigKind() string { - return kind -} - -var _ tools.ToolConfig = Config{} +var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Session *gocql.Session + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -113,8 +95,8 @@ func (t Tool) ToConfig() tools.ToolConfig { } // RequiresClientAuthorization implements tools.Tool. -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } // Authorized implements tools.Tool. @@ -124,6 +106,11 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { // Invoke implements tools.Tool. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -135,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - iter := t.Session.Query(newStatement, sliceParams...).IterContext(ctx) + iter := source.CassandraSession().Query(newStatement, sliceParams...).IterContext(ctx) // Create a slice to store the out var out []map[string]interface{} @@ -170,8 +157,6 @@ func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) return parameters.ParseParams(t.AllParams, data, claims) } -var _ tools.Tool = Tool{} - -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go b/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go index 4e5e0448ee..826d20d482 100644 --- a/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go +++ b/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go @@ -25,12 +25,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/util/parameters" ) -type compatibleSource interface { - ClickHousePool() *sql.DB -} - -var compatibleSources = []string{"clickhouse"} - const executeSQLKind string = "clickhouse-execute-sql" func init() { @@ -47,6 +41,10 @@ func newExecuteSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder return actual, nil } +type compatibleSource interface { + ClickHousePool() *sql.DB +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,16 +60,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", executeSQLKind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The SQL statement to execute.") params := parameters.Parameters{sqlParameter} @@ -80,7 +68,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.ClickHousePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -91,9 +78,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -103,13 +88,18 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"]) } - results, err := t.Pool.QueryContext(ctx, sql) + results, err := source.ClickHousePool().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -183,10 +173,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go b/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go index 9015e511cb..e6df548907 100644 --- a/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go +++ b/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go @@ -25,12 +25,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/util/parameters" ) -type compatibleSource interface { - ClickHousePool() *sql.DB -} - -var compatibleSources = []string{"clickhouse"} - const listDatabasesKind string = "clickhouse-list-databases" func init() { @@ -47,6 +41,10 @@ func newListDatabasesConfig(ctx context.Context, name string, decoder *yaml.Deco return actual, nil } +type compatibleSource interface { + ClickHousePool() *sql.DB +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -63,23 +61,12 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listDatabasesKind, compatibleSources) - } - allParameters, paramManifest, _ := parameters.ProcessParameters(nil, cfg.Parameters) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.ClickHousePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -90,9 +77,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -102,10 +87,15 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + // Query to list all databases query := "SHOW DATABASES" - results, err := t.Pool.QueryContext(ctx, query) + results, err := source.ClickHousePool().QueryContext(ctx, query) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -146,10 +136,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases_test.go b/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases_test.go index 768b942b41..ca6d9b21b7 100644 --- a/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases_test.go +++ b/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases_test.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/server" - "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -32,21 +31,6 @@ func TestListDatabasesConfigToolConfigKind(t *testing.T) { } } -func TestListDatabasesConfigInitializeMissingSource(t *testing.T) { - cfg := Config{ - Name: "test-list-databases", - Kind: listDatabasesKind, - Source: "missing-source", - Description: "Test list databases tool", - } - - srcs := map[string]sources.Source{} - _, err := cfg.Initialize(srcs) - if err == nil { - t.Error("expected error for missing source") - } -} - func TestParseFromYamlClickHouseListDatabases(t *testing.T) { ctx, err := testutils.ContextWithNewLogger() if err != nil { diff --git a/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go b/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go index 16a3b45911..e882a88ea5 100644 --- a/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go +++ b/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go @@ -25,12 +25,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/util/parameters" ) -type compatibleSource interface { - ClickHousePool() *sql.DB -} - -var compatibleSources = []string{"clickhouse"} - const listTablesKind string = "clickhouse-list-tables" const databaseKey string = "database" @@ -48,6 +42,10 @@ func newListTablesConfig(ctx context.Context, name string, decoder *yaml.Decoder return actual, nil } +type compatibleSource interface { + ClickHousePool() *sql.DB +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -64,16 +62,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listTablesKind, compatibleSources) - } - databaseParameter := parameters.NewStringParameter(databaseKey, "The database to list tables from.") params := parameters.Parameters{databaseParameter} @@ -83,7 +71,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.ClickHousePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -94,9 +81,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -106,6 +91,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() database, ok := mapParams[databaseKey].(string) if !ok { @@ -115,7 +105,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Query to list all tables in the specified database query := fmt.Sprintf("SHOW TABLES FROM %s", database) - results, err := t.Pool.QueryContext(ctx, query) + results, err := source.ClickHousePool().QueryContext(ctx, query) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -157,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables_test.go b/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables_test.go index 2705ded3fc..4500dac099 100644 --- a/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables_test.go +++ b/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables_test.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/server" - "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -32,21 +31,6 @@ func TestListTablesConfigToolConfigKind(t *testing.T) { } } -func TestListTablesConfigInitializeMissingSource(t *testing.T) { - cfg := Config{ - Name: "test-list-tables", - Kind: listTablesKind, - Source: "missing-source", - Description: "Test list tables tool", - } - - srcs := map[string]sources.Source{} - _, err := cfg.Initialize(srcs) - if err == nil { - t.Error("expected error for missing source") - } -} - func TestParseFromYamlClickHouseListTables(t *testing.T) { ctx, err := testutils.ContextWithNewLogger() if err != nil { diff --git a/internal/tools/clickhouse/clickhousesql/clickhousesql.go b/internal/tools/clickhouse/clickhousesql/clickhousesql.go index 6dade66701..83a2f1ee9d 100644 --- a/internal/tools/clickhouse/clickhousesql/clickhousesql.go +++ b/internal/tools/clickhouse/clickhousesql/clickhousesql.go @@ -25,21 +25,15 @@ import ( "github.com/googleapis/genai-toolbox/internal/util/parameters" ) -type compatibleSource interface { - ClickHousePool() *sql.DB -} - -var compatibleSources = []string{"clickhouse"} - const sqlKind string = "clickhouse-sql" func init() { - if !tools.Register(sqlKind, newSQLConfig) { + if !tools.Register(sqlKind, newConfig) { panic(fmt.Sprintf("tool kind %q already registered", sqlKind)) } } -func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { actual := Config{Name: name} if err := decoder.DecodeContext(ctx, &actual); err != nil { return nil, err @@ -47,6 +41,10 @@ func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tool return actual, nil } +type compatibleSource interface { + ClickHousePool() *sql.DB +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -65,23 +63,12 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", sqlKind, compatibleSources) - } - allParameters, paramManifest, _ := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.ClickHousePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -93,7 +80,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest } @@ -103,6 +89,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -115,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.ClickHousePool().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -191,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/clickhouse/clickhousesql/clickhousesql_test.go b/internal/tools/clickhouse/clickhousesql/clickhousesql_test.go index 4c127bd734..3c50305e28 100644 --- a/internal/tools/clickhouse/clickhousesql/clickhousesql_test.go +++ b/internal/tools/clickhouse/clickhousesql/clickhousesql_test.go @@ -142,66 +142,6 @@ func TestSQLConfigInitializeValidSource(t *testing.T) { } } -func TestSQLConfigInitializeMissingSource(t *testing.T) { - config := Config{ - Name: "test-tool", - Kind: sqlKind, - Source: "missing-source", - Description: "Test tool", - Statement: "SELECT 1", - Parameters: parameters.Parameters{}, - } - - sources := map[string]sources.Source{} - - _, err := config.Initialize(sources) - if err == nil { - t.Fatal("Expected error for missing source, got nil") - } - - expectedErr := `no source named "missing-source" configured` - if err.Error() != expectedErr { - t.Errorf("Expected error %q, got %q", expectedErr, err.Error()) - } -} - -// mockIncompatibleSource is a mock source that doesn't implement the compatibleSource interface -type mockIncompatibleSource struct{} - -func (m *mockIncompatibleSource) SourceKind() string { - return "mock" -} - -func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig { - return nil -} - -func TestSQLConfigInitializeIncompatibleSource(t *testing.T) { - config := Config{ - Name: "test-tool", - Kind: sqlKind, - Source: "incompatible-source", - Description: "Test tool", - Statement: "SELECT 1", - Parameters: parameters.Parameters{}, - } - - mockSource := &mockIncompatibleSource{} - - sources := map[string]sources.Source{ - "incompatible-source": mockSource, - } - - _, err := config.Initialize(sources) - if err == nil { - t.Fatal("Expected error for incompatible source, got nil") - } - - if err.Error() == "" { - t.Error("Expected non-empty error message") - } -} - func TestToolManifest(t *testing.T) { tool := Tool{ manifest: tools.Manifest{ diff --git a/internal/tools/cloudgda/cloudgda.go b/internal/tools/cloudgda/cloudgda.go new file mode 100644 index 0000000000..bf54c26c3f --- /dev/null +++ b/internal/tools/cloudgda/cloudgda.go @@ -0,0 +1,206 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudgda + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +const kind string = "cloud-gemini-data-analytics-query" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + GetProjectID() string + GetBaseURL() string + UseClientAuthorization() bool + GetClient(context.Context, string) (*http.Client, error) +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Location string `yaml:"location" validate:"required"` + Context *QueryDataContext `yaml:"context" validate:"required"` + GenerationOptions *GenerationOptions `yaml:"generationOptions,omitempty"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // Define the parameters for the Gemini Data Analytics Query API + // The prompt is the only input parameter. + allParameters := parameters.Parameters{ + parameters.NewStringParameterWithRequired("prompt", "The natural language question to ask.", true), + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + return Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + }, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + AllParams parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +// Invoke executes the tool logic +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + prompt, ok := paramsMap["prompt"].(string) + if !ok { + return nil, fmt.Errorf("prompt parameter not found or not a string") + } + + // The API endpoint itself always uses the "global" location. + apiLocation := "global" + apiParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), apiLocation) + apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", source.GetBaseURL(), apiParent) + + // The parent in the request payload uses the tool's configured location. + payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location) + + payload := &QueryDataRequest{ + Parent: payloadParent, + Prompt: prompt, + Context: t.Context, + GenerationOptions: t.GenerationOptions, + } + + bodyBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request payload: %w", err) + } + + // Parse the access token if provided + var tokenStr string + if source.UseClientAuthorization() { + var err error + tokenStr, err = accessToken.ParseBearerToken() + if err != nil { + return nil, fmt.Errorf("error parsing access token: %w", err) + } + } + + client, err := source.GetClient(ctx, tokenStr) + if err != nil { + return nil, fmt.Errorf("failed to get HTTP client: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody)) + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return result, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.AllParams, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + +func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) { + return "Authorization", nil +} diff --git a/internal/tools/cloudgda/cloudgda_test.go b/internal/tools/cloudgda/cloudgda_test.go new file mode 100644 index 0000000000..0d57032904 --- /dev/null +++ b/internal/tools/cloudgda/cloudgda_test.go @@ -0,0 +1,353 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudgda_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/server/resources" + "github.com/googleapis/genai-toolbox/internal/sources" + cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools" + cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +func TestParseFromYaml(t *testing.T) { + t.Parallel() + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + my-gda-query-tool: + kind: cloud-gemini-data-analytics-query + source: gda-api-source + description: Test Description + location: us-central1 + context: + datasourceReferences: + spannerReference: + databaseReference: + projectId: "cloud-db-nl2sql" + region: "us-central1" + instanceId: "evalbench" + databaseId: "financial" + engine: "GOOGLE_SQL" + agentContextReference: + contextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates" + generationOptions: + generateQueryResult: true + `, + want: map[string]tools.ToolConfig{ + "my-gda-query-tool": cloudgdatool.Config{ + Name: "my-gda-query-tool", + Kind: "cloud-gemini-data-analytics-query", + Source: "gda-api-source", + Description: "Test Description", + Location: "us-central1", + AuthRequired: []string{}, + Context: &cloudgdatool.QueryDataContext{ + DatasourceReferences: &cloudgdatool.DatasourceReferences{ + SpannerReference: &cloudgdatool.SpannerReference{ + DatabaseReference: &cloudgdatool.SpannerDatabaseReference{ + ProjectID: "cloud-db-nl2sql", + Region: "us-central1", + InstanceID: "evalbench", + DatabaseID: "financial", + Engine: cloudgdatool.SpannerEngineGoogleSQL, + }, + AgentContextReference: &cloudgdatool.AgentContextReference{ + ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates", + }, + }, + }, + }, + GenerationOptions: &cloudgdatool.GenerationOptions{ + GenerateQueryResult: true, + }, + }, + }, + }, + } + for _, tc := range tcs { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if !cmp.Equal(tc.want, got.Tools) { + t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools) + } + }) + } +} + +// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header. +type authRoundTripper struct { + Token string + Next http.RoundTripper +} + +func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + newReq := *req + newReq.Header = make(http.Header) + for k, v := range req.Header { + newReq.Header[k] = v + } + newReq.Header.Set("Authorization", rt.Token) + if rt.Next == nil { + return http.DefaultTransport.RoundTrip(&newReq) + } + return rt.Next.RoundTrip(&newReq) +} + +type mockSource struct { + kind string + client *http.Client // Can be used to inject a specific client + baseURL string // BaseURL is needed to implement sources.Source.BaseURL + config cloudgdasrc.Config // to return from ToConfig +} + +func (m *mockSource) SourceKind() string { return m.kind } +func (m *mockSource) ToConfig() sources.SourceConfig { return m.config } +func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) { + if m.client != nil { + return m.client, nil + } + // Default client for testing if not explicitly set + transport := &http.Transport{} + authTransport := &authRoundTripper{ + Token: "Bearer test-access-token", // Dummy token + Next: transport, + } + return &http.Client{Transport: authTransport}, nil +} +func (m *mockSource) UseClientAuthorization() bool { return false } +func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) { + return m, nil +} +func (m *mockSource) BaseURL() string { return m.baseURL } + +func TestInitialize(t *testing.T) { + t.Parallel() + + srcs := map[string]sources.Source{ + "gda-api-source": &cloudgdasrc.Source{ + Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"}, + Client: &http.Client{}, + BaseURL: cloudgdasrc.Endpoint, + }, + } + + tcs := []struct { + desc string + cfg cloudgdatool.Config + }{ + { + desc: "successful initialization", + cfg: cloudgdatool.Config{ + Name: "my-gda-query-tool", + Kind: "cloud-gemini-data-analytics-query", + Source: "gda-api-source", + Description: "Test Description", + Location: "us-central1", + }, + }, + } + + // Add an incompatible source for testing + srcs["incompatible-source"] = &mockSource{kind: "another-kind"} + + for _, tc := range tcs { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + tool, err := tc.cfg.Initialize(srcs) + if err != nil { + t.Fatalf("did not expect an error but got: %v", err) + } + // Basic sanity check on the returned tool + _ = tool // Avoid unused variable error + }) + } +} + +func TestInvoke(t *testing.T) { + t.Parallel() + // Mock the HTTP client and server for Invoke testing + serverMux := http.NewServeMux() + // Update expected URL path to include the location "us-central1" + serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST method, got %s", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Read and unmarshal the request body + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("failed to read request body: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + var reqPayload cloudgdatool.QueryDataRequest + if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil { + t.Errorf("failed to unmarshal request payload: %v", err) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Verify expected fields + if r.Header.Get("Authorization") == "" { + t.Errorf("expected Authorization header, got empty") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" { + t.Errorf("unexpected prompt: %s", reqPayload.Prompt) + } + + // Verify payload's parent uses the tool's configured location + if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") { + t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1")) + } + + // Verify context from config + if reqPayload.Context == nil || + reqPayload.Context.DatasourceReferences == nil || + reqPayload.Context.DatasourceReferences.SpannerReference == nil || + reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil || + reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" { + t.Errorf("unexpected context: %v", reqPayload.Context) + } + + // Verify generation options from config + if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult { + t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions) + } + + // Simulate a successful response + resp := map[string]any{ + "queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;", + "naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.", + } + _ = json.NewEncoder(w).Encode(resp) + }) + + mockServer := httptest.NewServer(serverMux) + defer mockServer.Close() + + ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent") + + // Create an authenticated client that uses the mock server + authTransport := &authRoundTripper{ + Token: "Bearer test-access-token", + Next: mockServer.Client().Transport, + } + authClient := &http.Client{Transport: authTransport} + + // Create a real cloudgdasrc.Source but inject the authenticated client + mockGdaSource := &cloudgdasrc.Source{ + Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"}, + Client: authClient, + BaseURL: mockServer.URL, + } + srcs := map[string]sources.Source{ + "mock-gda-source": mockGdaSource, + } + + // Initialize the tool config with context + toolCfg := cloudgdatool.Config{ + Name: "query-data-tool", + Kind: "cloud-gemini-data-analytics-query", + Source: "mock-gda-source", + Description: "Query Gemini Data Analytics", + Location: "us-central1", // Set location for the test + Context: &cloudgdatool.QueryDataContext{ + DatasourceReferences: &cloudgdatool.DatasourceReferences{ + SpannerReference: &cloudgdatool.SpannerReference{ + DatabaseReference: &cloudgdatool.SpannerDatabaseReference{ + ProjectID: "cloud-db-nl2sql", + Region: "us-central1", + InstanceID: "evalbench", + DatabaseID: "financial", + Engine: cloudgdatool.SpannerEngineGoogleSQL, + }, + AgentContextReference: &cloudgdatool.AgentContextReference{ + ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates", + }, + }, + }, + }, + GenerationOptions: &cloudgdatool.GenerationOptions{ + GenerateQueryResult: true, + }, + } + + tool, err := toolCfg.Initialize(srcs) + if err != nil { + t.Fatalf("failed to initialize tool: %v", err) + } + + // Prepare parameters for invocation - ONLY prompt + params := parameters.ParamValues{ + {Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"}, + } + + resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil) + + // Invoke the tool + result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client + if err != nil { + t.Fatalf("tool invocation failed: %v", err) + } + + // Validate the result + expectedResult := map[string]any{ + "queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;", + "naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.", + } + + if !cmp.Equal(expectedResult, result) { + t.Errorf("unexpected result: got %v, want %v", result, expectedResult) + } +} diff --git a/internal/tools/cloudgda/types.go b/internal/tools/cloudgda/types.go new file mode 100644 index 0000000000..8e82cb50c2 --- /dev/null +++ b/internal/tools/cloudgda/types.go @@ -0,0 +1,116 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudgda + +// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto + +// QueryDataRequest represents the JSON body for the queryData API +type QueryDataRequest struct { + Parent string `json:"parent"` + Prompt string `json:"prompt"` + Context *QueryDataContext `json:"context,omitempty"` + GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"` +} + +// QueryDataContext reflects the proto definition for the query context. +type QueryDataContext struct { + DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"` +} + +// DatasourceReferences reflects the proto definition for datasource references, using a oneof. +type DatasourceReferences struct { + SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"` + AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"` + CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"` +} + +// SpannerReference reflects the proto definition for Spanner database reference. +type SpannerReference struct { + DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"` + AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"` +} + +// SpannerDatabaseReference reflects the proto definition for a Spanner database reference. +type SpannerDatabaseReference struct { + Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"` + ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"` + Region string `json:"region,omitempty" yaml:"region,omitempty"` + InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"` + DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"` + TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"` +} + +// SpannerEngine represents the engine of the Spanner instance. +type SpannerEngine string + +const ( + SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED" + SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL" + SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL" +) + +// AlloyDBReference reflects the proto definition for an AlloyDB database reference. +type AlloyDBReference struct { + DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"` + AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"` +} + +// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference. +type AlloyDBDatabaseReference struct { + ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"` + Region string `json:"region,omitempty" yaml:"region,omitempty"` + ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"` + InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"` + DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"` + TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"` +} + +// CloudSQLReference reflects the proto definition for a Cloud SQL database reference. +type CloudSQLReference struct { + DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"` + AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"` +} + +// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference. +type CloudSQLDatabaseReference struct { + Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"` + ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"` + Region string `json:"region,omitempty" yaml:"region,omitempty"` + InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"` + DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"` + TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"` +} + +// CloudSQLEngine represents the engine of the Cloud SQL instance. +type CloudSQLEngine string + +const ( + CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED" + CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL" + CloudSQLEngineMySQL CloudSQLEngine = "MYSQL" +) + +// AgentContextReference reflects the proto definition for agent context. +type AgentContextReference struct { + ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"` +} + +// GenerationOptions reflects the proto definition for generation options. +type GenerationOptions struct { + GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"` + GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"` + GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"` + GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"` +} diff --git a/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage/cloudhealthcarefhirfetchpage.go b/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage/cloudhealthcarefhirfetchpage.go index 5a4c22c471..025ca9310f 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage/cloudhealthcarefhirfetchpage.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage/cloudhealthcarefhirfetchpage.go @@ -62,11 +62,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -83,35 +78,16 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - urlParameter := parameters.NewStringParameter(pageURLKey, "The full URL of the FHIR page to fetch. This would be the value of `Bundle.entry.link.url` field within the response returned from FHIR search or FHIR patient everything operations.") params := parameters.Parameters{urlParameter} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -121,14 +97,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -136,13 +107,18 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + url, ok := params.AsMap()[pageURLKey].(string) if !ok { return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey) } var httpClient *http.Client - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) @@ -150,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr}) httpClient = oauth2.NewClient(ctx, ts) } else { - // The t.Service object holds a client with the default credentials. + // The source.Service() object holds a client with the default credentials. // However, the client is not exported, so we have to create a new one. var err error httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope) @@ -201,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything/cloudhealthcarefhirpatienteverything.go b/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything/cloudhealthcarefhirpatienteverything.go index 11745be090..b00d7c35ac 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything/cloudhealthcarefhirpatienteverything.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything/cloudhealthcarefhirpatienteverything.go @@ -62,11 +62,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -92,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } idParameter := parameters.NewStringParameter(patientIDKey, "The ID of the patient FHIR resource for which the information is required") @@ -106,17 +101,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -126,15 +114,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -142,7 +124,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) if err != nil { return nil, err } @@ -151,20 +138,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey) } - svc := t.Service + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", t.Project, t.Region, t.Dataset, storeID, patientID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID) var opts []googleapi.CallOption if val, ok := params.AsMap()[typeFilterKey]; ok { types, ok := val.([]any) @@ -225,10 +212,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch/cloudhealthcarefhirpatientsearch.go b/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch/cloudhealthcarefhirpatientsearch.go index 23acc98a4d..c1cf43b59f 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch/cloudhealthcarefhirpatientsearch.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch/cloudhealthcarefhirpatientsearch.go @@ -78,11 +78,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -108,7 +103,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{ @@ -140,17 +135,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -160,15 +148,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -176,19 +158,24 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } @@ -261,7 +248,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para opts = append(opts, googleapi.QueryParameter("_summary", "text")) } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...) if err != nil { return nil, fmt.Errorf("failed to search patient resources: %w", err) @@ -298,10 +285,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetdataset/cloudhealthcaregetdataset.go b/internal/tools/cloudhealthcare/cloudhealthcaregetdataset/cloudhealthcaregetdataset.go index 4b07558300..d3386cb657 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetdataset/cloudhealthcaregetdataset.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetdataset/cloudhealthcaregetdataset.go @@ -51,11 +51,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -72,33 +67,15 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -108,13 +85,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - Project, Region, Dataset string - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -122,22 +95,26 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - svc := t.Service - var err error + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset) + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID()) dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do() if err != nil { return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err) @@ -161,10 +138,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore/cloudhealthcaregetdicomstore.go b/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore/cloudhealthcaregetdicomstore.go index ee81cf5fc3..d8da9c096e 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore/cloudhealthcaregetdicomstore.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore/cloudhealthcaregetdicomstore.go @@ -53,11 +53,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{} @@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -114,15 +102,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID) + storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do() if err != nil { return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err) @@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics/cloudhealthcaregetdicomstoremetrics.go b/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics/cloudhealthcaregetdicomstoremetrics.go index b6b9f5fbc9..03f73dd0a4 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics/cloudhealthcaregetdicomstoremetrics.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics/cloudhealthcaregetdicomstoremetrics.go @@ -53,11 +53,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{} @@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -114,15 +102,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID) + storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do() if err != nil { return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err) @@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource/cloudhealthcaregetfhirresource.go b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource/cloudhealthcaregetfhirresource.go index e8c00d78c1..41c4e71db2 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource/cloudhealthcaregetfhirresource.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource/cloudhealthcaregetfhirresource.go @@ -59,11 +59,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -89,7 +84,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } typeParameter := parameters.NewStringParameter(typeKey, "The FHIR resource type to retrieve (e.g., Patient, Observation).") @@ -102,17 +97,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -122,15 +110,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -138,7 +120,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) if err != nil { return nil, err } @@ -152,20 +139,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey) } - svc := t.Service + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", t.Project, t.Region, t.Dataset, storeID, resType, resID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", source.Project(), source.Region(), source.DatasetID(), storeID, resType, resID) call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name) call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8") resp, err := call.Do() @@ -204,10 +191,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore/cloudhealthcaregetfhirstore.go b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore/cloudhealthcaregetfhirstore.go index 0a42c25190..1760579b35 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore/cloudhealthcaregetfhirstore.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore/cloudhealthcaregetfhirstore.go @@ -53,11 +53,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{} @@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -114,15 +102,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID) + storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do() if err != nil { return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err) @@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics/cloudhealthcaregetfhirstoremetrics.go b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics/cloudhealthcaregetfhirstoremetrics.go index 7c1f60363d..29e1011da2 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics/cloudhealthcaregetfhirstoremetrics.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics/cloudhealthcaregetfhirstoremetrics.go @@ -53,11 +53,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{} @@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -114,15 +102,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID) + storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do() if err != nil { return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err) @@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores/cloudhealthcarelistdicomstores.go b/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores/cloudhealthcarelistdicomstores.go index 8e25aa52f5..e180a8028f 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores/cloudhealthcarelistdicomstores.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores/cloudhealthcarelistdicomstores.go @@ -53,11 +53,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -111,15 +87,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -127,29 +97,33 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - svc := t.Service - var err error + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset) + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID()) stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do() if err != nil { return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err) } var filtered []*healthcare.DicomStore for _, store := range stores.DicomStores { - if len(t.AllowedStores) == 0 { + if len(source.AllowedDICOMStores()) == 0 { filtered = append(filtered, store) continue } @@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para continue } parts := strings.Split(store.Name, "/") - if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok { + if _, ok := source.AllowedDICOMStores()[parts[len(parts)-1]]; ok { filtered = append(filtered, store) } } @@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores/cloudhealthcarelistfhirstores.go b/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores/cloudhealthcarelistfhirstores.go index 287311b09a..5e9ea52359 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores/cloudhealthcarelistfhirstores.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores/cloudhealthcarelistfhirstores.go @@ -53,11 +53,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -111,15 +87,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -127,29 +97,33 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - svc := t.Service - var err error + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset) + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID()) stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do() if err != nil { return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err) } var filtered []*healthcare.FhirStore for _, store := range stores.FhirStores { - if len(t.AllowedStores) == 0 { + if len(source.AllowedFHIRStores()) == 0 { filtered = append(filtered, store) continue } @@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para continue } parts := strings.Split(store.Name, "/") - if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok { + if _, ok := source.AllowedFHIRStores()[parts[len(parts)-1]]; ok { filtered = append(filtered, store) } } @@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance/cloudhealthcareretrieverendereddicominstance.go b/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance/cloudhealthcareretrieverendereddicominstance.go index 076b3cae58..6272fda5df 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance/cloudhealthcareretrieverendereddicominstance.go +++ b/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance/cloudhealthcareretrieverendereddicominstance.go @@ -61,11 +61,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -91,7 +86,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{ @@ -107,17 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -127,15 +115,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -143,19 +125,24 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } @@ -177,7 +164,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey) } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame) call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath) call.Header().Set("Accept", "image/jpeg") @@ -214,10 +201,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances/cloudhealthcaresearchdicominstances.go b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances/cloudhealthcaresearchdicominstances.go index 50021bba41..afe0f4cc2e 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances/cloudhealthcaresearchdicominstances.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances/cloudhealthcaresearchdicominstances.go @@ -68,11 +68,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -98,7 +93,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{ @@ -121,17 +116,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -141,15 +129,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -157,19 +139,24 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } @@ -204,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...) if err != nil { return nil, fmt.Errorf("failed to search dicom instances: %w", err) @@ -244,10 +231,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries/cloudhealthcaresearchdicomseries.go b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries/cloudhealthcaresearchdicomseries.go index 00c51db961..0c888f8d9c 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries/cloudhealthcaresearchdicomseries.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries/cloudhealthcaresearchdicomseries.go @@ -65,11 +65,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -95,7 +90,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{ @@ -117,17 +112,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -137,15 +125,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -153,19 +135,24 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } @@ -187,7 +174,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...) if err != nil { return nil, fmt.Errorf("failed to search dicom series: %w", err) @@ -227,10 +214,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies/cloudhealthcaresearchdicomstudies.go b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies/cloudhealthcaresearchdicomstudies.go index d22c8832e4..8a5e7ccf0d 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies/cloudhealthcaresearchdicomstudies.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies/cloudhealthcaresearchdicomstudies.go @@ -63,11 +63,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -93,7 +88,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{ @@ -113,17 +108,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -133,15 +121,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -149,19 +131,24 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } @@ -171,7 +158,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, err } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, "studies").Do(opts...) if err != nil { return nil, fmt.Errorf("failed to search dicom studies: %w", err) @@ -211,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudmonitoring/cloudmonitoring.go b/internal/tools/cloudmonitoring/cloudmonitoring.go index acfebeb8ca..54c19f6774 100644 --- a/internal/tools/cloudmonitoring/cloudmonitoring.go +++ b/internal/tools/cloudmonitoring/cloudmonitoring.go @@ -23,7 +23,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - cloudmonitoringsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -44,6 +43,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + BaseURL() string + Client() *http.Client + UserAgent() string +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -60,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*cloudmonitoringsrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloudmonitoring`", kind) - } - // Define the parameters internally instead of from the config file. allParameters := parameters.Parameters{ parameters.NewStringParameterWithRequired("projectId", "The Id of the Google Cloud project.", true), @@ -83,9 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - BaseURL: s.BaseURL, - UserAgent: s.UserAgent, - Client: s.Client, manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -97,9 +87,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - BaseURL string `yaml:"baseURL"` - UserAgent string - Client *http.Client manifest tools.Manifest mcpManifest tools.McpManifest } @@ -109,6 +96,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() projectID, ok := paramsMap["projectId"].(string) if !ok { @@ -119,7 +111,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("query parameter not found or not a string") } - url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", t.BaseURL, projectID) + url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", source.BaseURL(), projectID) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -130,9 +122,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para q.Add("query", query) req.URL.RawQuery = q.Encode() - req.Header.Set("User-Agent", t.UserAgent) + req.Header.Set("User-Agent", source.UserAgent()) - resp, err := t.Client.Do(req) + resp, err := source.Client().Do(req) if err != nil { return nil, err } @@ -175,10 +167,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudmonitoring/cloudmonitoring_test.go b/internal/tools/cloudmonitoring/cloudmonitoring_test.go index 4707adafec..51c4d00c21 100644 --- a/internal/tools/cloudmonitoring/cloudmonitoring_test.go +++ b/internal/tools/cloudmonitoring/cloudmonitoring_test.go @@ -81,22 +81,6 @@ func TestInitialize(t *testing.T) { AuthRequired: []string{"google-auth-service"}, }, }, - { - desc: "Error: source not found", - cfg: cloudmonitoring.Config{ - Name: "test-tool", - Source: "non-existent-source", - }, - wantErr: `no source named "non-existent-source" configured`, - }, - { - desc: "Error: incompatible source kind", - cfg: cloudmonitoring.Config{ - Name: "test-tool", - Source: "incompatible-source", - }, - wantErr: "invalid source for \"cloud-monitoring-query-prometheus\" tool", - }, } for _, tc := range testCases { diff --git a/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go b/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go index c644a5a0e6..e8f7431f8b 100644 --- a/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go +++ b/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the clone-instance tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -108,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the clone-instance tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -120,6 +123,10 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -156,7 +163,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para CloneContext: cloneContext, } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -189,10 +196,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go b/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go index 6f4a4b11a4..57b4cc06d6 100644 --- a/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go +++ b/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the create-database tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -93,7 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -103,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-database tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -115,6 +118,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -136,7 +144,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Instance: instance, } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -169,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go b/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go index 71ac68c217..148ccfeb6c 100644 --- a/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go +++ b/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the create-user tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -95,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -105,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-user tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -149,7 +157,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para user.Password = password } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -182,10 +190,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go b/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go index d1ecc621f0..1fb40b67bc 100644 --- a/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go +++ b/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go @@ -20,9 +20,9 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-get-instance" @@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the get-instances tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -65,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("projectId", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -92,7 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -102,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the get-instances tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -114,6 +118,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() projectId, ok := paramsMap["projectId"].(string) @@ -125,7 +134,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("missing 'instanceId' parameter") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -158,10 +167,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go b/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go index 32c9f01f01..ba54380631 100644 --- a/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go +++ b/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go @@ -20,9 +20,9 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - cloudsqladminsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-list-databases" @@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the list-databases tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -64,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladminsrc.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -91,7 +97,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -102,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Source *cloudsqladminsrc.Source manifest tools.Manifest mcpManifest tools.McpManifest } @@ -113,6 +117,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -124,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("missing 'instance' parameter") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -176,10 +185,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go b/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go index 51d5829c0b..11ccd91bad 100644 --- a/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go +++ b/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go @@ -20,9 +20,9 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - cloudsqladminsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-list-instances" @@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the list-instance tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -64,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladminsrc.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -90,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -101,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - source *cloudsqladminsrc.Source manifest tools.Manifest mcpManifest tools.McpManifest } @@ -112,6 +116,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -119,7 +128,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("missing 'project' parameter") } - service, err := t.source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -169,10 +178,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go b/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go index f46dc9c724..672f999282 100644 --- a/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go +++ b/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go @@ -25,9 +25,9 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-wait-for-operation" @@ -87,6 +87,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the wait-for-operation tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -118,12 +124,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -177,7 +183,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -191,17 +196,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the wait-for-operation tool. type Tool struct { Config - Source *cloudsqladmin.Source - AllParams parameters.Parameters `yaml:"allParams"` + AllParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest // Polling configuration Delay time.Duration MaxDelay time.Duration Multiplier float64 MaxRetries int - - manifest tools.Manifest - mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -210,6 +213,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -221,7 +229,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("missing 'operation' parameter") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -267,7 +275,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("could not unmarshal operation: %w", err) } - if msg, ok := t.generateCloudSQLConnectionMessage(data); ok { + if msg, ok := t.generateCloudSQLConnectionMessage(source, data); ok { return msg, nil } return string(opBytes), nil @@ -305,11 +313,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (string, bool) { +func (t Tool) generateCloudSQLConnectionMessage(source compatibleSource, opResponse map[string]any) (string, bool) { operationType, ok := opResponse["operationType"].(string) if !ok || operationType != "CREATE_DATABASE" { return "", false @@ -329,7 +341,7 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri instance := matches[2] database := matches[3] - instanceData, err := t.fetchInstanceData(context.Background(), project, instance) + instanceData, err := t.fetchInstanceData(context.Background(), source, project, instance) if err != nil { fmt.Printf("error fetching instance data: %v\n", err) return "", false @@ -385,8 +397,8 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri return b.String(), true } -func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (map[string]any, error) { - service, err := t.Source.GetService(ctx, "") +func (t Tool) fetchInstanceData(ctx context.Context, source compatibleSource, project, instance string) (map[string]any, error) { + service, err := source.GetService(ctx, "") if err != nil { return nil, err } @@ -408,6 +420,6 @@ func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) ( return data, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go b/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go index caa5aac470..78bc77d6fa 100644 --- a/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go +++ b/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the create-instances tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-instances tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Project: project, } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go b/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go index d9eedb69df..165a057c35 100644 --- a/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go +++ b/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the create-instances tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-instances tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Project: project, } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go b/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go index edbcecd652..224cc3700c 100644 --- a/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go +++ b/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the create-instances tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-instances tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Project: project, } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck/cloudsqlpgupgradeprecheck.go b/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck/cloudsqlpgupgradeprecheck.go index 5cde40216d..156d648e93 100644 --- a/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck/cloudsqlpgupgradeprecheck.go +++ b/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck/cloudsqlpgupgradeprecheck.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -43,6 +42,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the precheck-upgrade tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -62,15 +66,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize initializes the tool from the configuration. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - s, ok := rawS.(*cloudsqladmin.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) - } - allParameters := parameters.Parameters{ parameters.NewStringParameter("project", "The project ID"), parameters.NewStringParameter("instance", "The name of the instance to check"), @@ -88,28 +83,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) return Tool{ - Name: cfg.Name, - Kind: kind, - AuthRequired: cfg.AuthRequired, - Source: s, - AllParams: allParameters, - manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, }, nil } // Tool represents the precheck-upgrade tool. type Tool struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Description string `yaml:"description"` - AuthRequired []string `yaml:"authRequired"` - - Source *cloudsqladmin.Source + Config AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest - Config } // PreCheckResultItem holds the details of a single check result. @@ -146,6 +132,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the tool's logic. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -162,7 +153,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("missing or empty 'targetDatabaseVersion' parameter") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, fmt.Errorf("failed to get HTTP client from source: %w", err) } @@ -234,10 +225,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/couchbase/couchbase.go b/internal/tools/couchbase/couchbase.go index 2149691c82..481c9f6b22 100644 --- a/internal/tools/couchbase/couchbase.go +++ b/internal/tools/couchbase/couchbase.go @@ -22,7 +22,6 @@ import ( "github.com/couchbase/gocb/v2" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/couchbase" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -48,11 +47,6 @@ type compatibleSource interface { CouchbaseQueryScanConsistency() uint } -// validate compatible sources are still compatible -var _ compatibleSource = &couchbase.Source{} - -var compatibleSources = [...]string{couchbase.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -72,18 +66,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -92,12 +74,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) // finish tool setup t := Tool{ - Config: cfg, - AllParams: allParameters, - Scope: s.CouchbaseScope(), - QueryScanConsistency: s.CouchbaseQueryScanConsistency(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -107,12 +87,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Scope *gocb.Scope - QueryScanConsistency uint - manifest tools.Manifest - mcpManifest tools.McpManifest + AllParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -120,6 +97,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + namedParamsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, namedParamsMap) if err != nil { @@ -130,8 +112,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("unable to extract standard params %w", err) } - results, err := t.Scope.Query(newStatement, &gocb.QueryOptions{ - ScanConsistency: gocb.QueryScanConsistency(t.QueryScanConsistency), + results, err := source.CouchbaseScope().Query(newStatement, &gocb.QueryOptions{ + ScanConsistency: gocb.QueryScanConsistency(source.CouchbaseQueryScanConsistency()), NamedParameters: newParams.AsMap(), }) if err != nil { @@ -166,10 +148,10 @@ func (t Tool) Authorized(verifiedAuthSources []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/dataform/dataformcompilelocal/dataformcompilelocal.go b/internal/tools/dataform/dataformcompilelocal/dataformcompilelocal.go index d7e88428e6..daf6d4f29d 100644 --- a/internal/tools/dataform/dataformcompilelocal/dataformcompilelocal.go +++ b/internal/tools/dataform/dataformcompilelocal/dataformcompilelocal.go @@ -118,10 +118,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go b/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go index cdf3f62c41..78915c7b96 100644 --- a/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go +++ b/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go @@ -22,7 +22,6 @@ import ( dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -47,11 +46,6 @@ type compatibleSource interface { CatalogClient() *dataplexapi.CatalogClient } -// validate compatible sources are still compatible -var _ compatibleSource = &dataplexds.Source{} - -var compatibleSources = [...]string{dataplexds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,17 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // Initialize the search configuration with the provided sources - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - viewDesc := ` ## Argument: view @@ -104,9 +87,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) t := Tool{ - Config: cfg, - Parameters: params, - CatalogClient: s.CatalogClient(), + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -119,10 +101,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) type Tool struct { Config - Parameters parameters.Parameters - CatalogClient *dataplexapi.CatalogClient - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -130,6 +111,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() viewMap := map[int]dataplexpb.EntryView{ 1: dataplexpb.EntryView_BASIC, @@ -153,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Entry: entry, } - result, err := t.CatalogClient.LookupEntry(ctx, req) + result, err := source.CatalogClient().LookupEntry(ctx, req) if err != nil { return nil, err } @@ -179,10 +165,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go b/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go index e450a15a55..37f44cf9ea 100644 --- a/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go +++ b/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go @@ -23,7 +23,6 @@ import ( "github.com/cenkalti/backoff/v5" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -49,11 +48,6 @@ type compatibleSource interface { ProjectID() string } -// validate compatible sources are still compatible -var _ compatibleSource = &dataplexds.Source{} - -var compatibleSources = [...]string{dataplexds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -70,17 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // Initialize the search configuration with the provided sources - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - query := parameters.NewStringParameter("query", "The query against which aspect type should be matched.") pageSize := parameters.NewIntParameterWithDefault("pageSize", 5, "Number of returned aspect types in the search page.") orderBy := parameters.NewStringParameterWithDefault("orderBy", "relevance", "Specifies the ordering of results. Supported values are: relevance, last_modified_timestamp, last_modified_timestamp asc") @@ -89,10 +72,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) t := Tool{ - Config: cfg, - Parameters: params, - CatalogClient: s.CatalogClient(), - ProjectID: s.ProjectID(), + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -105,11 +86,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) type Tool struct { Config - Parameters parameters.Parameters - CatalogClient *dataplexapi.CatalogClient - ProjectID string - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -117,6 +96,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + // Invoke the tool with the provided parameters paramsMap := params.AsMap() query, _ := paramsMap["query"].(string) @@ -126,16 +110,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Create SearchEntriesRequest with the provided parameters req := &dataplexpb.SearchEntriesRequest{ Query: query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype", - Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID), + Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()), PageSize: pageSize, OrderBy: orderBy, SemanticSearch: true, } // Perform the search using the CatalogClient - this will return an iterator - it := t.CatalogClient.SearchEntries(ctx, req) + it := source.CatalogClient().SearchEntries(ctx, req) if it == nil { - return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID) + return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID()) } // Create an instance of exponential backoff with default values for retrying GetAspectType calls @@ -155,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } operation := func() (*dataplexpb.AspectType, error) { - aspectType, err := t.CatalogClient.GetAspectType(ctx, getAspectTypeReq) + aspectType, err := source.CatalogClient().GetAspectType(ctx, getAspectTypeReq) if err != nil { return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err) } @@ -192,10 +176,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go b/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go index 114601bbb1..76c3208bbf 100644 --- a/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go +++ b/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go @@ -22,7 +22,6 @@ import ( dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -48,11 +47,6 @@ type compatibleSource interface { ProjectID() string } -// validate compatible sources are still compatible -var _ compatibleSource = &dataplexds.Source{} - -var compatibleSources = [...]string{dataplexds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,17 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // Initialize the search configuration with the provided sources - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - query := parameters.NewStringParameter("query", "The query against which entries in scope should be matched.") pageSize := parameters.NewIntParameterWithDefault("pageSize", 5, "Number of results in the search page.") orderBy := parameters.NewStringParameterWithDefault("orderBy", "relevance", "Specifies the ordering of results. Supported values are: relevance, last_modified_timestamp, last_modified_timestamp asc") @@ -88,10 +71,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) t := Tool{ - Config: cfg, - Parameters: params, - CatalogClient: s.CatalogClient(), - ProjectID: s.ProjectID(), + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -104,11 +85,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) type Tool struct { Config - Parameters parameters.Parameters - CatalogClient *dataplexapi.CatalogClient - ProjectID string - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -116,6 +95,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() query, _ := paramsMap["query"].(string) pageSize := int32(paramsMap["pageSize"].(int)) @@ -123,15 +107,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para req := &dataplexpb.SearchEntriesRequest{ Query: query, - Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID), + Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()), PageSize: pageSize, OrderBy: orderBy, SemanticSearch: true, } - it := t.CatalogClient.SearchEntries(ctx, req) + it := source.CatalogClient().SearchEntries(ctx, req) if it == nil { - return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID) + return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID()) } var results []*dataplexpb.SearchEntriesResult @@ -163,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/dgraph/dgraph.go b/internal/tools/dgraph/dgraph.go index 4615a177a2..beef9f86a5 100644 --- a/internal/tools/dgraph/dgraph.go +++ b/internal/tools/dgraph/dgraph.go @@ -46,11 +46,6 @@ type compatibleSource interface { DgraphClient() *dgraph.DgraphClient } -// validate compatible sources are still compatible -var _ compatibleSource = &dgraph.Source{} - -var compatibleSources = [...]string{dgraph.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,26 +66,13 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, cfg.Parameters, nil) // finish tool setup t := Tool{ - Config: cfg, - DgraphClient: s.DgraphClient(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -100,9 +82,8 @@ var _ tools.Tool = Tool{} type Tool struct { Config - DgraphClient *dgraph.DgraphClient - manifest tools.Manifest - mcpManifest tools.McpManifest + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -110,9 +91,14 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMapWithDollarPrefix() - resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout) + resp, err := source.DgraphClient().ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout) if err != nil { return nil, err } @@ -148,10 +134,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go index 7b432bee63..d7cbb35722 100644 --- a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go +++ b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go @@ -43,10 +43,6 @@ type compatibleSource interface { ElasticsearchClient() es.EsClient } -var _ compatibleSource = &es.Source{} - -var compatibleSources = [...]string{es.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -77,29 +73,15 @@ type Tool struct { Config manifest tools.Manifest mcpManifest tools.McpManifest - EsClient es.EsClient } var _ tools.Tool = Tool{} func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - src, ok := srcs[c.Source] - if !ok { - return nil, fmt.Errorf("source %q not found", c.Source) - } - - // verify the source is compatible - s, ok := src.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - mcpManifest := tools.GetMcpManifest(c.Name, c.Description, c.AuthRequired, c.Parameters, nil) return Tool{ Config: c, - EsClient: s.ElasticsearchClient(), manifest: tools.Manifest{Description: c.Description, Parameters: c.Parameters.Manifest(), AuthRequired: c.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -120,6 +102,11 @@ type esqlResult struct { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + var cancel context.CancelFunc if t.Timeout > 0 { ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Timeout)*time.Second) @@ -164,8 +151,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Body: bytes.NewReader(body), Format: t.Format, FilterPath: []string{"columns", "values"}, - Instrument: t.EsClient.InstrumentationEnabled(), - }.Do(ctx, t.EsClient) + Instrument: source.ElasticsearchClient().InstrumentationEnabled(), + }.Do(ctx, source.ElasticsearchClient()) if err != nil { return nil, err @@ -230,10 +217,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go b/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go index 97fd2296e7..28c8d0fb63 100644 --- a/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go +++ b/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/firebird" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -47,10 +46,6 @@ type compatibleSource interface { FirebirdDB() *sql.DB } -var _ compatibleSource = &firebird.Source{} - -var compatibleSources = [...]string{firebird.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -66,16 +61,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -84,7 +69,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Db: s.FirebirdDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -95,9 +79,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Db *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -107,6 +89,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -120,7 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - rows, err := t.Db.QueryContext(ctx, sql) + rows, err := source.FirebirdDB().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -180,10 +167,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firebird/firebirdsql/firebirdsql.go b/internal/tools/firebird/firebirdsql/firebirdsql.go index f249dca46f..9dd040dcd7 100644 --- a/internal/tools/firebird/firebirdsql/firebirdsql.go +++ b/internal/tools/firebird/firebirdsql/firebirdsql.go @@ -22,7 +22,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/firebird" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -47,11 +46,6 @@ type compatibleSource interface { FirebirdDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &firebird.Source{} - -var compatibleSources = [...]string{firebird.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Db: s.FirebirdDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -106,9 +87,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Db *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -118,6 +97,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() statement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -142,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - rows, err := t.Db.QueryContext(ctx, statement, namedArgs...) + rows, err := source.FirebirdDB().QueryContext(ctx, statement, namedArgs...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -204,10 +188,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestoreadddocuments/firestoreadddocuments.go b/internal/tools/firestore/firestoreadddocuments/firestoreadddocuments.go index 18d2d9354a..a1cf8b5bd8 100644 --- a/internal/tools/firestore/firestoreadddocuments/firestoreadddocuments.go +++ b/internal/tools/firestore/firestoreadddocuments/firestoreadddocuments.go @@ -21,7 +21,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -50,11 +49,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Create parameters collectionPathParameter := parameters.NewStringParameter( collectionPathKey, @@ -124,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -136,9 +117,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Client *firestoreapi.Client + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -148,6 +127,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() // Get collection path @@ -169,7 +153,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Convert the document data from JSON format to Firestore format // The client is passed to handle referenceValue types - documentData, err := util.JSONToFirestoreValue(documentDataRaw, t.Client) + documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient()) if err != nil { return nil, fmt.Errorf("failed to convert document data: %w", err) } @@ -181,7 +165,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Get the collection reference - collection := t.Client.Collection(collectionPath) + collection := source.FirestoreClient().Collection(collectionPath) // Add the document to the collection docRef, writeResult, err := collection.Add(ctx, documentData) @@ -221,10 +205,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestoredeletedocuments/firestoredeletedocuments.go b/internal/tools/firestore/firestoredeletedocuments/firestoredeletedocuments.go index b1d95a58e2..00dfffccd3 100644 --- a/internal/tools/firestore/firestoredeletedocuments/firestoredeletedocuments.go +++ b/internal/tools/firestore/firestoredeletedocuments/firestoredeletedocuments.go @@ -21,7 +21,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -48,11 +47,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - documentPathsParameter := parameters.NewArrayParameter(documentPathsKey, "Array of relative document paths to delete from Firestore (e.g., 'users/userId' or 'users/userId/posts/postId'). Note: These are relative paths, NOT absolute paths like 'projects/{project_id}/databases/{database_id}/documents/...'", parameters.NewStringParameter("item", "Relative document path")) params := parameters.Parameters{documentPathsParameter} @@ -90,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -102,9 +83,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Client *firestoreapi.Client + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -114,6 +93,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() documentPathsRaw, ok := mapParams[documentPathsKey].([]any) if !ok { @@ -143,14 +127,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Create a BulkWriter to handle multiple deletions efficiently - bulkWriter := t.Client.BulkWriter(ctx) + bulkWriter := source.FirestoreClient().BulkWriter(ctx) // Keep track of jobs for each document jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths)) // Add all delete operations to the BulkWriter for i, path := range documentPaths { - docRef := t.Client.Doc(path) + docRef := source.FirestoreClient().Doc(path) job, err := bulkWriter.Delete(docRef) if err != nil { return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err) @@ -198,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestoregetdocuments/firestoregetdocuments.go b/internal/tools/firestore/firestoregetdocuments/firestoregetdocuments.go index 0a9666d8d8..9b8c253f5e 100644 --- a/internal/tools/firestore/firestoregetdocuments/firestoregetdocuments.go +++ b/internal/tools/firestore/firestoregetdocuments/firestoregetdocuments.go @@ -21,7 +21,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -48,11 +47,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - documentPathsParameter := parameters.NewArrayParameter(documentPathsKey, "Array of relative document paths to retrieve from Firestore (e.g., 'users/userId' or 'users/userId/posts/postId'). Note: These are relative paths, NOT absolute paths like 'projects/{project_id}/databases/{database_id}/documents/...'", parameters.NewStringParameter("item", "Relative document path")) params := parameters.Parameters{documentPathsParameter} @@ -90,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -102,9 +83,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Client *firestoreapi.Client + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -114,6 +93,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() documentPathsRaw, ok := mapParams[documentPathsKey].([]any) if !ok { @@ -145,11 +129,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Create document references from paths docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths)) for i, path := range documentPaths { - docRefs[i] = t.Client.Doc(path) + docRefs[i] = source.FirestoreClient().Doc(path) } // Get all documents - snapshots, err := t.Client.GetAll(ctx, docRefs) + snapshots, err := source.FirestoreClient().GetAll(ctx, docRefs) if err != nil { return nil, fmt.Errorf("failed to get documents: %w", err) } @@ -190,10 +174,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestoregetrules/firestoregetrules.go b/internal/tools/firestore/firestoregetrules/firestoregetrules.go index eb958c445c..b05f6ff878 100644 --- a/internal/tools/firestore/firestoregetrules/firestoregetrules.go +++ b/internal/tools/firestore/firestoregetrules/firestoregetrules.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/firebaserules/v1" @@ -48,11 +47,6 @@ type compatibleSource interface { GetDatabaseId() string } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // No parameters needed for this tool params := parameters.Parameters{} @@ -90,9 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - RulesClient: s.FirebaseRulesClient(), - ProjectId: s.GetProjectId(), - DatabaseId: s.GetDatabaseId(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -104,11 +83,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - RulesClient *firebaserules.Service - ProjectId string - DatabaseId string + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -118,19 +93,24 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + // Get the latest release for Firestore - releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", t.ProjectId, t.DatabaseId) - release, err := t.RulesClient.Projects.Releases.Get(releaseName).Context(ctx).Do() + releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", source.GetProjectId(), source.GetDatabaseId()) + release, err := source.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do() if err != nil { return nil, fmt.Errorf("failed to get latest Firestore release: %w", err) } if release.RulesetName == "" { - return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", t.ProjectId, t.DatabaseId) + return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", source.GetProjectId(), source.GetDatabaseId()) } // Get the ruleset content - ruleset, err := t.RulesClient.Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do() + ruleset, err := source.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do() if err != nil { return nil, fmt.Errorf("failed to get ruleset content: %w", err) } @@ -158,10 +138,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestorelistcollections/firestorelistcollections.go b/internal/tools/firestore/firestorelistcollections/firestorelistcollections.go index 382161099a..af3df39dfa 100644 --- a/internal/tools/firestore/firestorelistcollections/firestorelistcollections.go +++ b/internal/tools/firestore/firestorelistcollections/firestorelistcollections.go @@ -21,7 +21,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -48,11 +47,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - emptyString := "" parentPathParameter := parameters.NewStringParameterWithDefault(parentPathKey, emptyString, "Relative parent document path to list subcollections from (e.g., 'users/userId'). If not provided, lists root collections. Note: This is a relative path, NOT an absolute path like 'projects/{project_id}/databases/{database_id}/documents/...'") params := parameters.Parameters{parentPathParameter} @@ -91,7 +73,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -103,9 +84,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Client *firestoreapi.Client + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -115,10 +94,14 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() var collectionRefs []*firestoreapi.CollectionRef - var err error // Check if parentPath is provided parentPath, hasParent := mapParams[parentPathKey].(string) @@ -130,14 +113,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // List subcollections of the specified document - docRef := t.Client.Doc(parentPath) + docRef := source.FirestoreClient().Doc(parentPath) collectionRefs, err = docRef.Collections(ctx).GetAll() if err != nil { return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err) } } else { // List root collections - collectionRefs, err = t.Client.Collections(ctx).GetAll() + collectionRefs, err = source.FirestoreClient().Collections(ctx).GetAll() if err != nil { return nil, fmt.Errorf("failed to list root collections: %w", err) } @@ -177,10 +160,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestorequery/firestorequery.go b/internal/tools/firestore/firestorequery/firestorequery.go index 8ae527452c..9434e57171 100644 --- a/internal/tools/firestore/firestorequery/firestorequery.go +++ b/internal/tools/firestore/firestorequery/firestorequery.go @@ -24,7 +24,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -52,12 +51,9 @@ var validOperators = map[string]bool{ // Error messages const ( - errFilterParseFailed = "failed to parse filters: %w" - errQueryExecutionFailed = "failed to execute query: %w" - errTemplateParseFailed = "failed to parse template: %w" - errTemplateExecFailed = "failed to execute template: %w" - errLimitParseFailed = "failed to parse limit value '%s': %w" - errSelectFieldParseFailed = "failed to parse select field: %w" + errFilterParseFailed = "failed to parse filters: %w" + errQueryExecutionFailed = "failed to execute query: %w" + errLimitParseFailed = "failed to parse limit value '%s': %w" ) func init() { @@ -79,11 +75,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - // Config represents the configuration for the Firestore query tool type Config struct { Name string `yaml:"name" validate:"required"` @@ -114,18 +105,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize creates a new Tool instance from the configuration func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Set default limit if not specified if cfg.Limit == "" { cfg.Limit = fmt.Sprintf("%d", defaultLimit) @@ -137,7 +116,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ Config: cfg, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -201,6 +179,11 @@ type QueryResponse struct { // Invoke executes the Firestore query based on the provided parameters func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() // Process collection path with template substitution @@ -210,7 +193,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Build the query - query, err := t.buildQuery(collectionPath, paramsMap) + query, err := t.buildQuery(source, collectionPath, paramsMap) if err != nil { return nil, err } @@ -220,8 +203,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // buildQuery constructs the Firestore query from parameters -func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firestoreapi.Query, error) { - collection := t.Client.Collection(collectionPath) +func (t Tool) buildQuery(source compatibleSource, collectionPath string, params map[string]any) (*firestoreapi.Query, error) { + collection := source.FirestoreClient().Collection(collectionPath) query := collection.Query // Process and apply filters if template is provided @@ -239,7 +222,7 @@ func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firesto } // Convert simplified filter to Firestore filter - if filter := t.convertToFirestoreFilter(simplifiedFilter); filter != nil { + if filter := t.convertToFirestoreFilter(source, simplifiedFilter); filter != nil { query = query.WhereEntity(filter) } } @@ -280,12 +263,12 @@ func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firesto } // convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter -func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.EntityFilter { +func (t Tool) convertToFirestoreFilter(source compatibleSource, filter SimplifiedFilter) firestoreapi.EntityFilter { // Handle AND filters if len(filter.And) > 0 { filters := make([]firestoreapi.EntityFilter, 0, len(filter.And)) for _, f := range filter.And { - if converted := t.convertToFirestoreFilter(f); converted != nil { + if converted := t.convertToFirestoreFilter(source, f); converted != nil { filters = append(filters, converted) } } @@ -299,7 +282,7 @@ func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.Ent if len(filter.Or) > 0 { filters := make([]firestoreapi.EntityFilter, 0, len(filter.Or)) for _, f := range filter.Or { - if converted := t.convertToFirestoreFilter(f); converted != nil { + if converted := t.convertToFirestoreFilter(source, f); converted != nil { filters = append(filters, converted) } } @@ -313,7 +296,7 @@ func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.Ent if filter.Field != "" && filter.Op != "" && filter.Value != nil { if validOperators[filter.Op] { // Convert the value using the Firestore native JSON converter - convertedValue, err := util.JSONToFirestoreValue(filter.Value, t.Client) + convertedValue, err := util.JSONToFirestoreValue(filter.Value, source.FirestoreClient()) if err != nil { // If conversion fails, use the original value convertedValue = filter.Value @@ -525,10 +508,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestorequerycollection/firestorequerycollection.go b/internal/tools/firestore/firestorequerycollection/firestorequerycollection.go index 72c4d27086..9601ecc099 100644 --- a/internal/tools/firestore/firestorequerycollection/firestorequerycollection.go +++ b/internal/tools/firestore/firestorequerycollection/firestorequerycollection.go @@ -23,7 +23,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -92,11 +91,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - // Config represents the configuration for the Firestore query collection tool type Config struct { Name string `yaml:"name" validate:"required"` @@ -116,18 +110,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize creates a new Tool instance from the configuration func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Create parameters params := createParameters() @@ -137,7 +119,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -199,9 +180,7 @@ var _ tools.Tool = Tool{} // Tool represents the Firestore query collection tool type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Client *firestoreapi.Client + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -266,6 +245,11 @@ type QueryResponse struct { // Invoke executes the Firestore query based on the provided parameters func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + // Parse parameters queryParams, err := t.parseQueryParameters(params) if err != nil { @@ -273,7 +257,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Build the query - query, err := t.buildQuery(queryParams) + query, err := t.buildQuery(source, queryParams) if err != nil { return nil, err } @@ -396,8 +380,8 @@ func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) { } // buildQuery constructs the Firestore query from parameters -func (t Tool) buildQuery(params *queryParameters) (*firestoreapi.Query, error) { - collection := t.Client.Collection(params.CollectionPath) +func (t Tool) buildQuery(source compatibleSource, params *queryParameters) (*firestoreapi.Query, error) { + collection := source.FirestoreClient().Collection(params.CollectionPath) query := collection.Query // Apply filters @@ -531,10 +515,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go b/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go index fa5576ce31..d08fdb9458 100644 --- a/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go +++ b/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go @@ -22,7 +22,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -52,11 +51,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -73,18 +67,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Create parameters documentPathParameter := parameters.NewStringParameter( documentPathKey, @@ -134,7 +116,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -146,9 +127,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Client *firestoreapi.Client + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -158,6 +137,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() // Get document path @@ -200,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Get the document reference - docRef := t.Client.Doc(documentPath) + docRef := source.FirestoreClient().Doc(documentPath) // Prepare update data var writeResult *firestoreapi.WriteResult @@ -211,7 +195,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para updates := make([]firestoreapi.Update, 0, len(updatePaths)) // Convert document data without delete markers - dataMap, err := util.JSONToFirestoreValue(documentDataRaw, t.Client) + dataMap, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient()) if err != nil { return nil, fmt.Errorf("failed to convert document data: %w", err) } @@ -239,7 +223,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para writeResult, writeErr = docRef.Update(ctx, updates) } else { // Update all fields in the document data (merge) - documentData, err := util.JSONToFirestoreValue(documentDataRaw, t.Client) + documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient()) if err != nil { return nil, fmt.Errorf("failed to convert document data: %w", err) } @@ -314,10 +298,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument_test.go b/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument_test.go index de2e3be40f..3311aeb86e 100644 --- a/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument_test.go +++ b/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument_test.go @@ -132,32 +132,6 @@ func TestConfig_Initialize(t *testing.T) { }, wantErr: false, }, - { - name: "source not found", - config: Config{ - Name: "test-update-document", - Kind: "firestore-update-document", - Source: "missing-source", - Description: "Update a document", - }, - sources: map[string]sources.Source{}, - wantErr: true, - errMsg: "no source named \"missing-source\" configured", - }, - { - name: "incompatible source", - config: Config{ - Name: "test-update-document", - Kind: "firestore-update-document", - Source: "wrong-source", - Description: "Update a document", - }, - sources: map[string]sources.Source{ - "wrong-source": &mockIncompatibleSource{}, - }, - wantErr: true, - errMsg: "invalid source for \"firestore-update-document\" tool", - }, } for _, tt := range tests { @@ -464,14 +438,3 @@ func TestGetFieldValue(t *testing.T) { }) } } - -// mockIncompatibleSource is a mock source that doesn't implement compatibleSource -type mockIncompatibleSource struct{} - -func (m *mockIncompatibleSource) SourceKind() string { - return "mock" -} - -func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig { - return nil -} diff --git a/internal/tools/firestore/firestorevalidaterules/firestorevalidaterules.go b/internal/tools/firestore/firestorevalidaterules/firestorevalidaterules.go index 028677cc99..69cbee4aa4 100644 --- a/internal/tools/firestore/firestorevalidaterules/firestorevalidaterules.go +++ b/internal/tools/firestore/firestorevalidaterules/firestorevalidaterules.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/firebaserules/v1" @@ -53,11 +52,6 @@ type compatibleSource interface { GetProjectId() string } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -74,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Create parameters params := createParameters() mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) @@ -94,8 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - RulesClient: s.FirebaseRulesClient(), - ProjectId: s.GetProjectId(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -117,10 +97,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - RulesClient *firebaserules.Service - ProjectId string + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -154,11 +131,16 @@ type ValidationResult struct { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() // Get source parameter - source, ok := mapParams[sourceKey].(string) - if !ok || source == "" { + sourceParam, ok := mapParams[sourceKey].(string) + if !ok || sourceParam == "" { return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey) } @@ -168,7 +150,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Files: []*firebaserules.File{ { Name: "firestore.rules", - Content: source, + Content: sourceParam, }, }, }, @@ -179,14 +161,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Call the test API - projectName := fmt.Sprintf("projects/%s", t.ProjectId) - response, err := t.RulesClient.Projects.Test(projectName, testRequest).Context(ctx).Do() + projectName := fmt.Sprintf("projects/%s", source.GetProjectId()) + response, err := source.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do() if err != nil { return nil, fmt.Errorf("failed to validate rules: %w", err) } // Process the response - result := t.processValidationResponse(response, source) + result := t.processValidationResponse(response, sourceParam) return result, nil } @@ -287,10 +269,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/http/http.go b/internal/tools/http/http.go index 4013d25d75..9e838b8b73 100644 --- a/internal/tools/http/http.go +++ b/internal/tools/http/http.go @@ -29,7 +29,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -50,6 +49,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + HttpDefaultHeaders() map[string]string + HttpBaseURL() string + HttpQueryParams() map[string]string + Client() *http.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -81,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } // verify the source is compatible - s, ok := rawS.(*httpsrc.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `http`", kind) } @@ -89,7 +95,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Combine Source and Tool headers. // In case of conflict, Tool header overrides Source header combinedHeaders := make(map[string]string) - maps.Copy(combinedHeaders, s.DefaultHeaders) + maps.Copy(combinedHeaders, s.HttpDefaultHeaders()) maps.Copy(combinedHeaders, cfg.Headers) // Create a slice for all parameters @@ -113,14 +119,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - BaseURL: s.BaseURL, - Headers: combinedHeaders, - DefaultQueryParams: s.QueryParams, - Client: s.Client, - AllParams: allParameters, - manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Headers: combinedHeaders, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, }, nil } @@ -129,12 +132,8 @@ var _ tools.Tool = Tool{} type Tool struct { Config - BaseURL string `yaml:"baseURL"` - Headers map[string]string `yaml:"headers"` - DefaultQueryParams map[string]string `yaml:"defaultQueryParams"` - AllParams parameters.Parameters `yaml:"allParams"` - - Client *http.Client + Headers map[string]string `yaml:"headers"` + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -229,6 +228,11 @@ func getHeaders(headerParams parameters.Parameters, defaultHeaders map[string]st } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() // Calculate request body @@ -238,7 +242,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Calculate URL - urlString, err := getURL(t.BaseURL, t.Path, t.PathParams, t.QueryParams, t.DefaultQueryParams, paramsMap) + urlString, err := getURL(source.HttpBaseURL(), t.Path, t.PathParams, t.QueryParams, source.HttpQueryParams(), paramsMap) if err != nil { return nil, fmt.Errorf("error populating path parameters: %s", err) } @@ -256,7 +260,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Make request and fetch response - resp, err := t.Client.Do(req) + resp, err := source.Client().Do(req) if err != nil { return nil, fmt.Errorf("error making HTTP request: %s", err) } @@ -295,10 +299,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/looker/lookeradddashboardelement/lookeradddashboardelement.go b/internal/tools/looker/lookeradddashboardelement/lookeradddashboardelement.go index 9101573cb8..8c2417157b 100644 --- a/internal/tools/looker/lookeradddashboardelement/lookeradddashboardelement.go +++ b/internal/tools/looker/lookeradddashboardelement/lookeradddashboardelement.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetQueryParameters() dashIdParameter := parameters.NewStringParameter("dashboard_id", "The id of the dashboard where this tile will exist") @@ -109,12 +103,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -129,13 +119,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -148,6 +134,11 @@ var ( ) func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -167,12 +158,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para visConfig := paramsMap["vis_config"].(map[string]any) wq.VisConfig = &visConfig - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - qresp, err := sdk.CreateQuery(*wq, "id", t.ApiSettings) + qresp, err := sdk.CreateQuery(*wq, "id", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create query request: %w", err) } @@ -239,7 +230,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Fields: &fields, } - resp, err := sdk.CreateDashboardElement(req, t.ApiSettings) + resp, err := sdk.CreateDashboardElement(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create dashboard element request: %w", err) } @@ -264,14 +255,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter.go b/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter.go index f1b9014574..bc01526aaa 100644 --- a/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter.go +++ b/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := parameters.Parameters{} dashIdParameter := parameters.NewStringParameter("dashboard_id", "The id of the dashboard where this filter will exist") @@ -109,14 +103,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Name: cfg.Name, - Kind: kind, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, - Parameters: params, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -131,16 +119,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Name string `yaml:"name"` - Kind string `yaml:"kind"` - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - AuthRequired []string `yaml:"authRequired"` - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -148,6 +129,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -205,12 +191,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para req.Dimension = &dimension } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := sdk.CreateDashboardFilter(req, "name", t.ApiSettings) + resp, err := sdk.CreateDashboardFilter(req, "name", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create dashboard filter request: %s", err) } @@ -239,10 +225,18 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go b/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go index c60ee650d8..ba09f4b6a6 100644 --- a/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go +++ b/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go @@ -26,7 +26,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookerds "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -56,12 +55,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - GetApiSettings() *rtl.ApiSettings GoogleCloudTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error) GoogleCloudProject() string GoogleCloudLocation() string UseClientAuthorization() bool GetAuthTokenHeaderName() string + LookerApiSettings() *rtl.ApiSettings } // Structs for building the JSON payload @@ -124,11 +123,6 @@ type CAPayload struct { ClientIdEnum string `json:"clientIdEnum"` } -// validate compatible sources are still compatible -var _ compatibleSource = &lookerds.Source{} - -var compatibleSources = [...]string{lookerds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -155,7 +149,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } if s.GoogleCloudProject() == "" { @@ -196,16 +190,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - ApiSettings: s.GetApiSettings(), - Project: s.GoogleCloudProject(), - Location: s.GoogleCloudLocation(), - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - TokenSource: ts, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + TokenSource: ts, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -215,15 +204,10 @@ var _ tools.Tool = Tool{} type Tool struct { Config - ApiSettings *rtl.ApiSettings - UseClientOAuth bool `yaml:"useClientOAuth"` - AuthTokenHeaderName string - Parameters parameters.Parameters `yaml:"parameters"` - Project string - Location string - TokenSource oauth2.TokenSource - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + TokenSource oauth2.TokenSource + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -231,8 +215,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + var tokenStr string - var err error // Get credentials for the API call // Use cloud-platform token source for Gemini Data Analytics API @@ -253,16 +241,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para ler := make([]LookerExploreReference, 0) for _, er := range exploreReferences { ler = append(ler, LookerExploreReference{ - LookerInstanceUri: t.ApiSettings.BaseUrl, + LookerInstanceUri: source.LookerApiSettings().BaseUrl, LookmlModel: er.(map[string]any)["model"].(string), Explore: er.(map[string]any)["explore"].(string), }) } oauth_creds := OAuthCredentials{} - if t.UseClientOAuth { + if source.UseClientAuthorization() { oauth_creds.Token = TokenBased{AccessToken: string(accessToken)} } else { - oauth_creds.Secret = SecretBased{ClientId: t.ApiSettings.ClientId, ClientSecret: t.ApiSettings.ClientSecret} + oauth_creds.Secret = SecretBased{ClientId: source.LookerApiSettings().ClientId, ClientSecret: source.LookerApiSettings().ClientSecret} } lers := LookerExploreReferences{ @@ -273,8 +261,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Construct URL, headers, and payload - projectID := t.Project - location := t.Location + projectID := source.GoogleCloudProject() + location := source.GoogleCloudLocation() caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1beta/projects/%s/locations/%s:chat", url.PathEscape(projectID), url.PathEscape(location)) headers := map[string]string{ @@ -315,12 +303,16 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } -func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } // StreamMessage represents a single message object from the streaming API response. @@ -563,6 +555,10 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s return append(messages, newMessage) } -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookercreateprojectfile/lookercreateprojectfile.go b/internal/tools/looker/lookercreateprojectfile/lookercreateprojectfile.go index 9d64a4f4fb..ddf53b94f4 100644 --- a/internal/tools/looker/lookercreateprojectfile/lookercreateprojectfile.go +++ b/internal/tools/looker/lookercreateprojectfile/lookercreateprojectfile.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,18 +67,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files") filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project") fileContentParameter := parameters.NewStringParameter("file_content", "The content of the file") @@ -90,12 +84,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -110,13 +100,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -124,7 +110,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -148,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Content: fileContent, } - err = lookercommon.CreateProjectFile(sdk, projectId, req, t.ApiSettings) + err = lookercommon.CreateProjectFile(sdk, projectId, req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create_project_file request: %s", err) } @@ -172,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go b/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go index 86e7450dd7..5c20c95635 100644 --- a/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go +++ b/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,18 +67,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files") filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project") params := parameters.Parameters{projectIdParameter, filePathParameter} @@ -91,12 +85,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -111,13 +101,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -125,7 +111,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -140,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("'file_path' must be a string, got %T", mapParams["file_path"]) } - err = lookercommon.DeleteProjectFile(sdk, projectId, filePath, t.ApiSettings) + err = lookercommon.DeleteProjectFile(sdk, projectId, filePath, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making delete_project_file request: %s", err) } @@ -164,14 +155,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerdevmode/lookerdevmode.go b/internal/tools/looker/lookerdevmode/lookerdevmode.go index e660f42a11..d33ed9c457 100644 --- a/internal/tools/looker/lookerdevmode/lookerdevmode.go +++ b/internal/tools/looker/lookerdevmode/lookerdevmode.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - devModeParameter := parameters.NewBooleanParameterWithDefault("devMode", true, "Whether to set Dev Mode.") params := parameters.Parameters{devModeParameter} @@ -89,19 +83,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenExplores: s.ShowHiddenExplores, + mcpManifest: mcpManifest, }, nil } @@ -110,14 +99,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenExplores bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -125,6 +109,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -135,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("'devMode' must be a boolean, got %T", mapParams["devMode"]) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -148,7 +137,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para req := v4.WriteApiSession{ WorkspaceId: &devModeString, } - resp, err := sdk.UpdateSession(req, t.ApiSettings) + resp, err := sdk.UpdateSession(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error setting/resetting dev mode: %w", err) } @@ -169,14 +158,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go b/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go index e1c1bb4003..8dbc4a1557 100644 --- a/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go +++ b/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -46,6 +45,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerSessionLength() int64 +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -63,18 +70,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - typeParameter := parameters.NewStringParameterWithDefault("type", "", "Type of Looker content to embed (ie. dashboards, looks, query-visualization)") idParameter := parameters.NewStringParameterWithDefault("id", "", "The ID of the content to embed.") params := parameters.Parameters{ @@ -94,19 +89,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - SessionLength: s.SessionLength, + mcpManifest: mcpManifest, }, nil } @@ -115,15 +105,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - AuthRequired []string `yaml:"authRequired"` - Parameters parameters.Parameters - manifest tools.Manifest - mcpManifest tools.McpManifest - SessionLength int64 + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -131,6 +115,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -147,16 +136,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para contentId_ptr = nil } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } forceLogoutLogin := true - + sessionLength := source.LookerSessionLength() req := v4.EmbedParams{ - TargetUrl: fmt.Sprintf("%s/embed/%s/%s", t.ApiSettings.BaseUrl, *embedType_ptr, *contentId_ptr), - SessionLength: &t.SessionLength, + TargetUrl: fmt.Sprintf("%s/embed/%s/%s", source.LookerApiSettings().BaseUrl, *embedType_ptr, *contentId_ptr), + SessionLength: &sessionLength, ForceLogoutLogin: &forceLogoutLogin, } logger.ErrorContext(ctx, "Making request %v", req) @@ -181,14 +170,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go b/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go index 2e424a37de..c637b92260 100644 --- a/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go +++ b/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,18 +67,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - connParameter := parameters.NewStringParameter("conn", "The connection containing the databases.") params := parameters.Parameters{connParameter} @@ -88,12 +82,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -108,13 +98,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -122,17 +108,22 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() conn, ok := mapParams["conn"].(string) if !ok { return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"]) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := sdk.ConnectionDatabases(conn, t.ApiSettings) + resp, err := sdk.ConnectionDatabases(conn, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_connection_databases request: %s", err) } @@ -153,14 +144,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetconnections/lookergetconnections.go b/internal/tools/looker/lookergetconnections/lookergetconnections.go index 821a88772e..75b4622a56 100644 --- a/internal/tools/looker/lookergetconnections/lookergetconnections.go +++ b/internal/tools/looker/lookergetconnections/lookergetconnections.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := parameters.Parameters{} annotations := cfg.Annotations @@ -88,12 +82,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -108,13 +98,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -122,16 +108,21 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := sdk.AllConnections("name, dialect(name), database, schema", t.ApiSettings) + resp, err := sdk.AllConnections("name, dialect(name), database, schema", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_connections request: %s", err) } @@ -147,7 +138,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if v.Schema != nil { vMap["schema"] = *v.Schema } - conn, err := sdk.ConnectionFeatures(*v.Name, "multiple_databases", t.ApiSettings) + conn, err := sdk.ConnectionFeatures(*v.Name, "multiple_databases", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_connection_features request: %s", err) } @@ -172,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go b/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go index 07d35ff375..6ceac7a205 100644 --- a/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go +++ b/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,18 +67,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - connParameter := parameters.NewStringParameter("conn", "The connection containing the schemas.") dbParameter := parameters.NewStringParameterWithRequired("db", "The optional database to search", false) params := parameters.Parameters{connParameter, dbParameter} @@ -89,12 +83,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -109,13 +99,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -123,6 +109,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() conn, ok := mapParams["conn"].(string) if !ok { @@ -130,7 +121,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } db, _ := mapParams["db"].(string) - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -140,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if db != "" { req.Database = &db } - resp, err := sdk.ConnectionSchemas(req, t.ApiSettings) + resp, err := sdk.ConnectionSchemas(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_connection_schemas request: %s", err) } @@ -159,14 +150,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go b/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go index 9eab689fc0..4b1991cacf 100644 --- a/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go +++ b/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - connParameter := parameters.NewStringParameter("conn", "The connection containing the tables.") dbParameter := parameters.NewStringParameterWithRequired("db", "The optional database to search", false) schemaParameter := parameters.NewStringParameter("schema", "The schema containing the tables.") @@ -92,12 +86,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -112,13 +102,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -126,6 +112,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -145,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("'tables' must be a string, got %T", mapParams["tables"]) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -157,7 +148,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if db != "" { req.Database = &db } - resp, err := sdk.ConnectionColumns(req, t.ApiSettings) + resp, err := sdk.ConnectionColumns(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_connection_table_columns request: %s", err) } @@ -196,14 +187,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go b/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go index 90771b9b63..1fd9df6515 100644 --- a/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go +++ b/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - connParameter := parameters.NewStringParameter("conn", "The connection containing the tables.") dbParameter := parameters.NewStringParameterWithRequired("db", "The optional database to search", false) schemaParameter := parameters.NewStringParameter("schema", "The schema containing the tables.") @@ -91,12 +85,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -111,13 +101,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -125,6 +111,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -140,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("'schema' must be a string, got %T", mapParams["schema"]) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -151,7 +142,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if db != "" { req.Database = &db } - resp, err := sdk.ConnectionTables(req, t.ApiSettings) + resp, err := sdk.ConnectionTables(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_connection_tables request: %s", err) } @@ -187,14 +178,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetdashboards/lookergetdashboards.go b/internal/tools/looker/lookergetdashboards/lookergetdashboards.go index ff7bd265a8..6ef5be2f45 100644 --- a/internal/tools/looker/lookergetdashboards/lookergetdashboards.go +++ b/internal/tools/looker/lookergetdashboards/lookergetdashboards.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - titleParameter := parameters.NewStringParameterWithDefault("title", "", "The title of the dashboard.") descParameter := parameters.NewStringParameterWithDefault("desc", "", "The description of the dashboard.") limitParameter := parameters.NewIntParameterWithDefault("limit", 100, "The number of dashboards to fetch. Default 100") @@ -97,12 +91,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -117,13 +107,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -131,6 +117,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -149,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para limit := int64(paramsMap["limit"].(int)) offset := int64(paramsMap["offset"].(int)) - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -160,7 +151,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Offset: &offset, } logger.ErrorContext(ctx, "Making request %v", req) - resp, err := sdk.SearchDashboards(req, t.ApiSettings) + resp, err := sdk.SearchDashboards(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_dashboards request: %s", err) } @@ -198,14 +189,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetdimensions/lookergetdimensions.go b/internal/tools/looker/lookergetdimensions/lookergetdimensions.go index b95187dd77..92c795dfb2 100644 --- a/internal/tools/looker/lookergetdimensions/lookergetdimensions.go +++ b/internal/tools/looker/lookergetdimensions/lookergetdimensions.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerShowHiddenFields() bool +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +69,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetFieldParameters() annotations := cfg.Annotations @@ -88,19 +83,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenFields: s.ShowHiddenFields, + mcpManifest: mcpManifest, }, nil } @@ -109,14 +99,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenFields bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -124,6 +109,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -133,7 +123,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error processing model or explore: %w", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -143,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para ExploreName: *explore, Fields: &fields, } - resp, err := sdk.LookmlModelExplore(req, t.ApiSettings) + resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_dimensions request: %w", err) } @@ -152,7 +142,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error processing get_dimensions response: %w", err) } - data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Dimensions, t.ShowHiddenFields) + data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Dimensions, source.LookerShowHiddenFields()) if err != nil { return nil, fmt.Errorf("error extracting get_dimensions response: %w", err) } @@ -173,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetexplores/lookergetexplores.go b/internal/tools/looker/lookergetexplores/lookergetexplores.go index 0c6e6d0ba2..75eaf9485a 100644 --- a/internal/tools/looker/lookergetexplores/lookergetexplores.go +++ b/internal/tools/looker/lookergetexplores/lookergetexplores.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerShowHiddenExplores() bool +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +69,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - modelParameter := parameters.NewStringParameter("model", "The model containing the explores.") params := parameters.Parameters{modelParameter} @@ -89,19 +84,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenExplores: s.ShowHiddenExplores, + mcpManifest: mcpManifest, }, nil } @@ -110,14 +100,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenExplores bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -125,6 +110,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -135,11 +125,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("'model' must be a string, got %T", mapParams["model"]) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := sdk.LookmlModel(model, "explores(name,description,label,group_label,hidden)", t.ApiSettings) + resp, err := sdk.LookmlModel(model, "explores(name,description,label,group_label,hidden)", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_explores request: %s", err) } @@ -147,7 +137,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para var data []any for _, v := range *resp.Explores { logger.DebugContext(ctx, "Got response element of %v\n", v) - if !t.ShowHiddenExplores && v.Hidden != nil && *v.Hidden { + if !source.LookerShowHiddenExplores() && v.Hidden != nil && *v.Hidden { continue } vMap := make(map[string]any) @@ -183,14 +173,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetfilters/lookergetfilters.go b/internal/tools/looker/lookergetfilters/lookergetfilters.go index 58fe004ea0..413874886b 100644 --- a/internal/tools/looker/lookergetfilters/lookergetfilters.go +++ b/internal/tools/looker/lookergetfilters/lookergetfilters.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerShowHiddenFields() bool +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +69,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetFieldParameters() annotations := cfg.Annotations @@ -88,19 +83,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenFields: s.ShowHiddenFields, + mcpManifest: mcpManifest, }, nil } @@ -109,14 +99,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenFields bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -124,6 +109,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -134,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } fields := lookercommon.FiltersFields - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -143,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para ExploreName: *explore, Fields: &fields, } - resp, err := sdk.LookmlModelExplore(req, t.ApiSettings) + resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_filters request: %w", err) } @@ -152,7 +142,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error processing get_filters response: %w", err) } - data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Filters, t.ShowHiddenFields) + data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Filters, source.LookerShowHiddenFields()) if err != nil { return nil, fmt.Errorf("error extracting get_filters response: %w", err) } @@ -173,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetlooks/lookergetlooks.go b/internal/tools/looker/lookergetlooks/lookergetlooks.go index 00d2abb46e..b52bc059b4 100644 --- a/internal/tools/looker/lookergetlooks/lookergetlooks.go +++ b/internal/tools/looker/lookergetlooks/lookergetlooks.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - titleParameter := parameters.NewStringParameterWithDefault("title", "", "The title of the look.") descParameter := parameters.NewStringParameterWithDefault("desc", "", "The description of the look.") limitParameter := parameters.NewIntParameterWithDefault("limit", 100, "The number of looks to fetch. Default 100") @@ -97,12 +91,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -117,13 +107,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -131,6 +117,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -149,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para limit := int64(paramsMap["limit"].(int)) offset := int64(paramsMap["offset"].(int)) - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -159,7 +150,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Limit: &limit, Offset: &offset, } - resp, err := sdk.SearchLooks(req, t.ApiSettings) + resp, err := sdk.SearchLooks(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_looks request: %s", err) } @@ -198,14 +189,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetmeasures/lookergetmeasures.go b/internal/tools/looker/lookergetmeasures/lookergetmeasures.go index 0a1f769a41..56b810126b 100644 --- a/internal/tools/looker/lookergetmeasures/lookergetmeasures.go +++ b/internal/tools/looker/lookergetmeasures/lookergetmeasures.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerShowHiddenFields() bool +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +69,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetFieldParameters() annotations := cfg.Annotations @@ -88,19 +83,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenFields: s.ShowHiddenFields, + mcpManifest: mcpManifest, }, nil } @@ -109,14 +99,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenFields bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -124,6 +109,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -134,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } fields := lookercommon.MeasuresFields - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -143,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para ExploreName: *explore, Fields: &fields, } - resp, err := sdk.LookmlModelExplore(req, t.ApiSettings) + resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_measures request: %w", err) } @@ -152,7 +142,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error processing get_measures response: %w", err) } - data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Measures, t.ShowHiddenFields) + data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Measures, source.LookerShowHiddenFields()) if err != nil { return nil, fmt.Errorf("error extracting get_measures response: %w", err) } @@ -173,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetmodels/lookergetmodels.go b/internal/tools/looker/lookergetmodels/lookergetmodels.go index 496db583df..5c4f70f6b1 100644 --- a/internal/tools/looker/lookergetmodels/lookergetmodels.go +++ b/internal/tools/looker/lookergetmodels/lookergetmodels.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerShowHiddenModels() bool +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +69,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := parameters.Parameters{} annotations := cfg.Annotations @@ -88,19 +83,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenModels: s.ShowHiddenModels, + mcpManifest: mcpManifest, }, nil } @@ -109,14 +99,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenModels bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -124,16 +109,21 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } excludeEmpty := false - excludeHidden := !t.ShowHiddenModels + excludeHidden := !source.LookerShowHiddenModels() includeInternal := true - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -142,7 +132,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para ExcludeHidden: &excludeHidden, IncludeInternal: &includeInternal, } - resp, err := sdk.AllLookmlModels(req, t.ApiSettings) + resp, err := sdk.AllLookmlModels(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_models request: %s", err) } @@ -175,14 +165,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetparameters/lookergetparameters.go b/internal/tools/looker/lookergetparameters/lookergetparameters.go index d9e6f807b7..2333cfb892 100644 --- a/internal/tools/looker/lookergetparameters/lookergetparameters.go +++ b/internal/tools/looker/lookergetparameters/lookergetparameters.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerShowHiddenFields() bool +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +69,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetFieldParameters() annotations := cfg.Annotations @@ -88,19 +83,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenFields: s.ShowHiddenFields, + mcpManifest: mcpManifest, }, nil } @@ -109,14 +99,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenFields bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -124,6 +109,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -134,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } fields := lookercommon.ParametersFields - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -143,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para ExploreName: *explore, Fields: &fields, } - resp, err := sdk.LookmlModelExplore(req, t.ApiSettings) + resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_parameters request: %w", err) } @@ -152,7 +142,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error processing get_parameters response: %w", err) } - data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Parameters, t.ShowHiddenFields) + data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Parameters, source.LookerShowHiddenFields()) if err != nil { return nil, fmt.Errorf("error extracting get_parameters response: %w", err) } @@ -173,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetprojectfile/lookergetprojectfile.go b/internal/tools/looker/lookergetprojectfile/lookergetprojectfile.go index 25258a2af3..6d3fd015d3 100644 --- a/internal/tools/looker/lookergetprojectfile/lookergetprojectfile.go +++ b/internal/tools/looker/lookergetprojectfile/lookergetprojectfile.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files") filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project") params := parameters.Parameters{projectIdParameter, filePathParameter} @@ -90,12 +84,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -110,13 +100,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -124,12 +110,17 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -144,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("'file_path' must be a string, got %T", mapParams["file_path"]) } - resp, err := lookercommon.GetProjectFileContent(sdk, projectId, filePath, t.ApiSettings) + resp, err := lookercommon.GetProjectFileContent(sdk, projectId, filePath, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_project_file request: %s", err) } @@ -169,14 +160,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetprojectfiles/lookergetprojectfiles.go b/internal/tools/looker/lookergetprojectfiles/lookergetprojectfiles.go index 0c805c7de6..78f3182246 100644 --- a/internal/tools/looker/lookergetprojectfiles/lookergetprojectfiles.go +++ b/internal/tools/looker/lookergetprojectfiles/lookergetprojectfiles.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files") params := parameters.Parameters{projectIdParameter} @@ -89,12 +83,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -109,13 +99,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -123,12 +109,17 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -139,7 +130,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("'project_id' must be a string, got %T", mapParams["project_id"]) } - resp, err := sdk.AllProjectFiles(projectId, "", t.ApiSettings) + resp, err := sdk.AllProjectFiles(projectId, "", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_project_files request: %s", err) } @@ -186,14 +177,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetprojects/lookergetprojects.go b/internal/tools/looker/lookergetprojects/lookergetprojects.go index c91ffec431..5756413662 100644 --- a/internal/tools/looker/lookergetprojects/lookergetprojects.go +++ b/internal/tools/looker/lookergetprojects/lookergetprojects.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := parameters.Parameters{} annotations := cfg.Annotations @@ -88,12 +82,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Parameters: params, - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -108,13 +98,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -122,17 +108,22 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := sdk.AllProjects("id,name", t.ApiSettings) + resp, err := sdk.AllProjects("id,name", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_models request: %s", err) } @@ -163,14 +154,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerhealthanalyze/lookerhealthanalyze.go b/internal/tools/looker/lookerhealthanalyze/lookerhealthanalyze.go index bc3422168f..0675b4dee5 100644 --- a/internal/tools/looker/lookerhealthanalyze/lookerhealthanalyze.go +++ b/internal/tools/looker/lookerhealthanalyze/lookerhealthanalyze.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -50,6 +49,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -67,16 +73,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - actionParameter := parameters.NewStringParameterWithRequired("action", "The analysis to run. Can be 'projects', 'models', or 'explores'.", true) projectParameter := parameters.NewStringParameterWithRequired("project", "The Looker project to analyze (optional).", false) modelParameter := parameters.NewStringParameterWithRequired("model", "The Looker model to analyze (optional).", false) @@ -104,12 +100,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -123,13 +115,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -137,12 +125,17 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -211,12 +204,16 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } -func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } // ================================================================================================================= @@ -566,6 +563,10 @@ func (t *analyzeTool) explores(ctx context.Context, model, explore string) ([]ma // END LOOKER HEALTH ANALYZE CORE LOGIC // ================================================================================================================= -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go b/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go index 73850edff0..45307b5011 100644 --- a/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go +++ b/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -50,6 +49,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -67,16 +73,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - actionParameter := parameters.NewStringParameterWithRequired("action", "The health check to run. Can be either: `check_db_connections`, `check_dashboard_performance`,`check_dashboard_errors`,`check_explore_performance`,`check_schedule_failures`, or `check_legacy_features`", true) params := parameters.Parameters{ @@ -95,12 +91,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -114,13 +106,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -128,18 +116,23 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } pulseTool := &pulseTool{ - ApiSettings: t.ApiSettings, + ApiSettings: source.LookerApiSettings(), SdkClient: sdk, } @@ -153,7 +146,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Action: action, } - result, err := pulseTool.RunPulse(ctx, pulseParams) + result, err := pulseTool.RunPulse(ctx, source, pulseParams) if err != nil { return nil, fmt.Errorf("error running pulse: %w", err) } @@ -175,12 +168,16 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } -func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } // ================================================================================================================= @@ -201,27 +198,27 @@ type pulseTool struct { SdkClient *v4.LookerSDK } -func (t *pulseTool) RunPulse(ctx context.Context, params PulseParams) (interface{}, error) { +func (t *pulseTool) RunPulse(ctx context.Context, source compatibleSource, params PulseParams) (interface{}, error) { switch params.Action { case "check_db_connections": - return t.checkDBConnections(ctx) + return t.checkDBConnections(ctx, source) case "check_dashboard_performance": - return t.checkDashboardPerformance(ctx) + return t.checkDashboardPerformance(ctx, source) case "check_dashboard_errors": - return t.checkDashboardErrors(ctx) + return t.checkDashboardErrors(ctx, source) case "check_explore_performance": - return t.checkExplorePerformance(ctx) + return t.checkExplorePerformance(ctx, source) case "check_schedule_failures": - return t.checkScheduleFailures(ctx) + return t.checkScheduleFailures(ctx, source) case "check_legacy_features": - return t.checkLegacyFeatures(ctx) + return t.checkLegacyFeatures(ctx, source) default: return nil, fmt.Errorf("unknown action: %s", params.Action) } } // Check DB connections and run tests -func (t *pulseTool) checkDBConnections(ctx context.Context) (interface{}, error) { +func (t *pulseTool) checkDBConnections(ctx context.Context, source compatibleSource) (interface{}, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -235,7 +232,7 @@ func (t *pulseTool) checkDBConnections(ctx context.Context) (interface{}, error) "looker__ilooker": {}, } - connections, err := t.SdkClient.AllConnections("", t.ApiSettings) + connections, err := t.SdkClient.AllConnections("", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error fetching connections: %w", err) } @@ -254,7 +251,7 @@ func (t *pulseTool) checkDBConnections(ctx context.Context) (interface{}, error) for _, conn := range filteredConnections { var errors []string // Test connection (simulate test_connection endpoint) - resp, err := t.SdkClient.TestConnection(*conn.Name, nil, t.ApiSettings) + resp, err := t.SdkClient.TestConnection(*conn.Name, nil, source.LookerApiSettings()) if err != nil { errors = append(errors, "API JSONDecode Error") } else { @@ -278,7 +275,7 @@ func (t *pulseTool) checkDBConnections(ctx context.Context) (interface{}, error) }, Limit: &limit, } - raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", t.ApiSettings) + raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", source.LookerApiSettings()) if err != nil { return nil, err } @@ -299,7 +296,7 @@ func (t *pulseTool) checkDBConnections(ctx context.Context) (interface{}, error) return results, nil } -func (t *pulseTool) checkDashboardPerformance(ctx context.Context) (interface{}, error) { +func (t *pulseTool) checkDashboardPerformance(ctx context.Context, source compatibleSource) (interface{}, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -320,7 +317,7 @@ func (t *pulseTool) checkDashboardPerformance(ctx context.Context) (interface{}, Sorts: &[]string{"query.count desc"}, Limit: &limit, } - raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", t.ApiSettings) + raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", source.LookerApiSettings()) if err != nil { return nil, err } @@ -331,7 +328,7 @@ func (t *pulseTool) checkDashboardPerformance(ctx context.Context) (interface{}, return dashboards, nil } -func (t *pulseTool) checkDashboardErrors(ctx context.Context) (interface{}, error) { +func (t *pulseTool) checkDashboardErrors(ctx context.Context, source compatibleSource) (interface{}, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -352,7 +349,7 @@ func (t *pulseTool) checkDashboardErrors(ctx context.Context) (interface{}, erro Sorts: &[]string{"history.query_run_count desc"}, Limit: &limit, } - raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", t.ApiSettings) + raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", source.LookerApiSettings()) if err != nil { return nil, err } @@ -363,7 +360,7 @@ func (t *pulseTool) checkDashboardErrors(ctx context.Context) (interface{}, erro return dashboards, nil } -func (t *pulseTool) checkExplorePerformance(ctx context.Context) (interface{}, error) { +func (t *pulseTool) checkExplorePerformance(ctx context.Context, source compatibleSource) (interface{}, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -382,7 +379,7 @@ func (t *pulseTool) checkExplorePerformance(ctx context.Context) (interface{}, e Sorts: &[]string{"history.average_runtime desc"}, Limit: &limit, } - raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", t.ApiSettings) + raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", source.LookerApiSettings()) if err != nil { return nil, err } @@ -393,7 +390,7 @@ func (t *pulseTool) checkExplorePerformance(ctx context.Context) (interface{}, e // Average query runtime query.Fields = &[]string{"history.average_runtime"} - rawAvg, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", t.ApiSettings) + rawAvg, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", source.LookerApiSettings()) if err != nil { return nil, err } @@ -408,7 +405,7 @@ func (t *pulseTool) checkExplorePerformance(ctx context.Context) (interface{}, e return explores, nil } -func (t *pulseTool) checkScheduleFailures(ctx context.Context) (interface{}, error) { +func (t *pulseTool) checkScheduleFailures(ctx context.Context, source compatibleSource) (interface{}, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -427,7 +424,7 @@ func (t *pulseTool) checkScheduleFailures(ctx context.Context) (interface{}, err Sorts: &[]string{"scheduled_job.count desc"}, Limit: &limit, } - raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", t.ApiSettings) + raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", source.LookerApiSettings()) if err != nil { return nil, err } @@ -438,14 +435,14 @@ func (t *pulseTool) checkScheduleFailures(ctx context.Context) (interface{}, err return schedules, nil } -func (t *pulseTool) checkLegacyFeatures(ctx context.Context) (interface{}, error) { +func (t *pulseTool) checkLegacyFeatures(ctx context.Context, source compatibleSource) (interface{}, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } logger.InfoContext(ctx, "Test 6/6: Checking for enabled legacy features") - features, err := t.SdkClient.AllLegacyFeatures(t.ApiSettings) + features, err := t.SdkClient.AllLegacyFeatures(source.LookerApiSettings()) if err != nil { if strings.Contains(err.Error(), "Unsupported in Looker (Google Cloud core)") { return []map[string]string{{"Feature": "Unsupported in Looker (Google Cloud core)"}}, nil @@ -466,6 +463,10 @@ func (t *pulseTool) checkLegacyFeatures(ctx context.Context) (interface{}, error // END LOOKER HEALTH PULSE CORE LOGIC // ================================================================================================================= -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go b/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go index e0d963e608..d1d55a2fd0 100644 --- a/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go +++ b/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -50,6 +49,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -67,16 +73,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - actionParameter := parameters.NewStringParameterWithRequired("action", "The vacuum action to run. Can be 'models', or 'explores'.", true) projectParameter := parameters.NewStringParameterWithDefault("project", "", "The Looker project to vacuum (optional).") modelParameter := parameters.NewStringParameterWithDefault("model", "", "The Looker model to vacuum (optional).") @@ -104,12 +100,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -123,13 +115,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -137,7 +125,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -189,12 +182,16 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } -func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } // ================================================================================================================= @@ -470,6 +467,10 @@ func (t *vacuumTool) getUsedExploreFields(ctx context.Context, model, explore st // END LOOKER HEALTH VACUUM CORE LOGIC // ================================================================================================================= -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookermakedashboard/lookermakedashboard.go b/internal/tools/looker/lookermakedashboard/lookermakedashboard.go index bc8c974935..2930d6e993 100644 --- a/internal/tools/looker/lookermakedashboard/lookermakedashboard.go +++ b/internal/tools/looker/lookermakedashboard/lookermakedashboard.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -47,6 +46,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -64,18 +70,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := parameters.Parameters{} titleParameter := parameters.NewStringParameter("title", "The title of the Dashboard") @@ -95,12 +89,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -115,13 +105,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -129,18 +115,23 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } logger.DebugContext(ctx, "params = ", params) - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } mrespFields := "id,personal_folder_id" - mresp, err := sdk.Me(mrespFields, t.ApiSettings) + mresp, err := sdk.Me(mrespFields, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making me request: %s", err) } @@ -153,7 +144,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("user does not have a personal folder. cannot continue") } - dashs, err := sdk.FolderDashboards(*mresp.PersonalFolderId, "title", t.ApiSettings) + dashs, err := sdk.FolderDashboards(*mresp.PersonalFolderId, "title", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error getting existing dashboards in folder: %s", err) } @@ -172,13 +163,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Description: &description, FolderId: mresp.PersonalFolderId, } - resp, err := sdk.CreateDashboard(wd, t.ApiSettings) + resp, err := sdk.CreateDashboard(wd, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create dashboard request: %s", err) } logger.DebugContext(ctx, "resp = %v", resp) - setting, err := sdk.GetSetting("host_url", t.ApiSettings) + setting, err := sdk.GetSetting("host_url", source.LookerApiSettings()) if err != nil { logger.ErrorContext(ctx, "error getting settings: %s", err) } @@ -211,14 +202,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookermakelook/lookermakelook.go b/internal/tools/looker/lookermakelook/lookermakelook.go index b6387f56d5..7244c5d6fe 100644 --- a/internal/tools/looker/lookermakelook/lookermakelook.go +++ b/internal/tools/looker/lookermakelook/lookermakelook.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -47,6 +46,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -64,18 +70,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetQueryParameters() titleParameter := parameters.NewStringParameter("title", "The title of the Look") @@ -101,12 +95,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -121,13 +111,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -135,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -145,12 +136,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error building query request: %w", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } mrespFields := "id,personal_folder_id" - mresp, err := sdk.Me(mrespFields, t.ApiSettings) + mresp, err := sdk.Me(mrespFields, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making me request: %s", err) } @@ -159,7 +150,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para title := paramsMap["title"].(string) description := paramsMap["description"].(string) - looks, err := sdk.FolderLooks(*mresp.PersonalFolderId, "title", t.ApiSettings) + looks, err := sdk.FolderLooks(*mresp.PersonalFolderId, "title", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error getting existing looks in folder: %s", err) } @@ -177,7 +168,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para wq.VisConfig = &visConfig qrespFields := "id" - qresp, err := sdk.CreateQuery(*wq, qrespFields, t.ApiSettings) + qresp, err := sdk.CreateQuery(*wq, qrespFields, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create query request: %s", err) } @@ -189,13 +180,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para QueryId: qresp.Id, FolderId: mresp.PersonalFolderId, } - resp, err := sdk.CreateLook(wlwq, "", t.ApiSettings) + resp, err := sdk.CreateLook(wlwq, "", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create look request: %s", err) } logger.DebugContext(ctx, "resp = %v", resp) - setting, err := sdk.GetSetting("host_url", t.ApiSettings) + setting, err := sdk.GetSetting("host_url", source.LookerApiSettings()) if err != nil { logger.ErrorContext(ctx, "error getting settings: %s", err) } @@ -228,14 +219,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerquery/lookerquery.go b/internal/tools/looker/lookerquery/lookerquery.go index b5cb69635f..7f37d71c76 100644 --- a/internal/tools/looker/lookerquery/lookerquery.go +++ b/internal/tools/looker/lookerquery/lookerquery.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -46,6 +45,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -63,18 +69,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetQueryParameters() annotations := cfg.Annotations @@ -89,12 +83,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -109,13 +99,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -123,6 +109,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -131,11 +122,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("error building WriteQuery request: %w", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := lookercommon.RunInlineQuery(ctx, sdk, wq, "json", t.ApiSettings) + resp, err := lookercommon.RunInlineQuery(ctx, sdk, wq, "json", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making query request: %s", err) } @@ -165,14 +156,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerquerysql/lookerquerysql.go b/internal/tools/looker/lookerquerysql/lookerquerysql.go index e93fc467e6..648894d8ed 100644 --- a/internal/tools/looker/lookerquerysql/lookerquerysql.go +++ b/internal/tools/looker/lookerquerysql/lookerquerysql.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetQueryParameters() annotations := cfg.Annotations @@ -88,12 +82,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -108,13 +98,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -122,6 +108,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -130,11 +121,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("error building query request: %w", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := lookercommon.RunInlineQuery(ctx, sdk, wq, "sql", t.ApiSettings) + resp, err := lookercommon.RunInlineQuery(ctx, sdk, wq, "sql", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making query request: %s", err) } @@ -155,14 +146,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerqueryurl/lookerqueryurl.go b/internal/tools/looker/lookerqueryurl/lookerqueryurl.go index 70390c6c60..f76e0014a2 100644 --- a/internal/tools/looker/lookerqueryurl/lookerqueryurl.go +++ b/internal/tools/looker/lookerqueryurl/lookerqueryurl.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetQueryParameters() vizParameter := parameters.NewMapParameterWithDefault("vis_config", @@ -95,12 +89,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -115,13 +105,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -129,6 +115,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -143,12 +134,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para visConfig := paramsMap["vis_config"].(map[string]any) wq.VisConfig = &visConfig - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } respFields := "id,slug,share_url,expanded_share_url" - resp, err := sdk.CreateQuery(*wq, respFields, t.ApiSettings) + resp, err := sdk.CreateQuery(*wq, respFields, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making query request: %s", err) } @@ -184,14 +175,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerrundashboard/lookerrundashboard.go b/internal/tools/looker/lookerrundashboard/lookerrundashboard.go index 7b9c9b2797..6a27a77e3a 100644 --- a/internal/tools/looker/lookerrundashboard/lookerrundashboard.go +++ b/internal/tools/looker/lookerrundashboard/lookerrundashboard.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -47,6 +46,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -64,18 +70,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - dashboardidParameter := parameters.NewStringParameter("dashboard_id", "The id of the dashboard to run.") params := parameters.Parameters{ @@ -94,12 +88,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -114,13 +104,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -128,6 +114,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -137,11 +128,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para dashboard_id := paramsMap["dashboard_id"].(string) - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - dashboard, err := sdk.Dashboard(dashboard_id, "", t.ApiSettings) + dashboard, err := sdk.Dashboard(dashboard_id, "", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error getting dashboard: %w", err) } @@ -157,7 +148,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para channels := make([]<-chan map[string]any, len(*dashboard.DashboardElements)) for i, element := range *dashboard.DashboardElements { - channels[i] = tileQueryWorker(ctx, sdk, t.ApiSettings, i, element) + channels[i] = tileQueryWorker(ctx, sdk, source.LookerApiSettings(), i, element) } for resp := range merge(channels...) { @@ -181,12 +172,16 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } -func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } func tileQueryWorker(ctx context.Context, sdk *v4.LookerSDK, options *rtl.ApiSettings, index int, element v4.DashboardElement) <-chan map[string]any { @@ -278,6 +273,10 @@ func merge(channels ...<-chan map[string]any) <-chan map[string]any { return out } -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerrunlook/lookerrunlook.go b/internal/tools/looker/lookerrunlook/lookerrunlook.go index 2c2fa9083b..9c7136b6c2 100644 --- a/internal/tools/looker/lookerrunlook/lookerrunlook.go +++ b/internal/tools/looker/lookerrunlook/lookerrunlook.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -46,6 +45,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -63,18 +69,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - lookidParameter := parameters.NewStringParameter("look_id", "The id of the look to run.") limitParameter := parameters.NewIntParameterWithDefault("limit", 500, "The row limit. Default 500") @@ -95,12 +89,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -115,13 +105,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -129,6 +115,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -140,12 +131,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para limit := int64(paramsMap["limit"].(int)) limitStr := fmt.Sprintf("%d", limit) - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - look, err := sdk.Look(look_id, "", t.ApiSettings) + look, err := sdk.Look(look_id, "", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error getting look definition: %s", err) } @@ -161,7 +152,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Limit: &limitStr, } - resp, err := lookercommon.RunInlineQuery(ctx, sdk, &wq, "json", t.ApiSettings) + resp, err := lookercommon.RunInlineQuery(ctx, sdk, &wq, "json", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making run_look request: %s", err) } @@ -194,10 +185,18 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go b/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go index c242545be5..2981f24270 100644 --- a/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go +++ b/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -44,6 +43,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,18 +66,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files") filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project") fileContentParameter := parameters.NewStringParameter("file_content", "The content of the file") @@ -92,12 +85,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -112,13 +101,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -126,7 +111,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -150,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para Content: fileContent, } - err = lookercommon.UpdateProjectFile(sdk, projectId, req, t.ApiSettings) + err = lookercommon.UpdateProjectFile(sdk, projectId, req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making update_project_file request: %s", err) } @@ -178,10 +168,18 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go b/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go index 2158ca33f8..51f2952177 100644 --- a/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go +++ b/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/mindsdb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -47,11 +46,6 @@ type compatibleSource interface { MindsDBPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &mindsdb.Source{} - -var compatibleSources = [...]string{mindsdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -68,18 +62,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.MindsDBPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -106,9 +87,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -118,13 +97,18 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) } - results, err := t.Pool.QueryContext(ctx, sql) + results, err := source.MindsDBPool().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -193,10 +177,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go b/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go index 5c07d00235..c247f4d4dc 100644 --- a/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go +++ b/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/mindsdb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -47,11 +46,6 @@ type compatibleSource interface { MindsDBPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &mindsdb.Source{} - -var compatibleSources = [...]string{mindsdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -100,7 +82,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.MindsDBPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -112,14 +93,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -134,7 +118,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sliceParams := newParams.AsSlice() // MindsDB now supports MySQL prepared statements natively - results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.MindsDBPool().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -203,14 +187,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go b/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go index 5f5d3d0018..ccf7655ca3 100644 --- a/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go +++ b/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go @@ -20,7 +20,6 @@ import ( "slices" "github.com/goccy/go-yaml" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -45,6 +44,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -67,18 +70,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.PipelineParams) @@ -96,7 +87,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -107,14 +97,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - database *mongo.Database + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() pipelineString, err := parameters.PopulateTemplateWithJSON("MongoDBAggregatePipeline", t.PipelinePayload, paramsMap) @@ -139,7 +132,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - cur, err := t.database.Collection(t.Collection).Aggregate(ctx, pipeline) + cur, err := source.MongoClient().Database(t.Database).Collection(t.Collection).Aggregate(ctx, pipeline) if err != nil { return nil, err } @@ -185,14 +178,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go b/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go index 80c852fed6..566113b34b 100644 --- a/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go +++ b/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go @@ -20,7 +20,6 @@ import ( "slices" "github.com/goccy/go-yaml" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -46,6 +45,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -66,18 +69,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.FilterParams) @@ -101,7 +92,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -112,14 +102,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - database *mongo.Database + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBDeleteManyFilter", t.FilterPayload, paramsMap) @@ -135,7 +128,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, err } - res, err := t.database.Collection(t.Collection).DeleteMany(ctx, filter, opts) + res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).DeleteMany(ctx, filter, opts) if err != nil { return nil, err } @@ -164,14 +157,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go b/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go index 0dd3cef756..6d16e5df70 100644 --- a/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go +++ b/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go @@ -19,7 +19,6 @@ import ( "slices" "github.com/goccy/go-yaml" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -45,6 +44,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -65,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.FilterParams) @@ -100,7 +91,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -111,14 +101,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - database *mongo.Database + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBDeleteOneFilter", t.FilterPayload, paramsMap) @@ -134,7 +127,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, err } - res, err := t.database.Collection(t.Collection).DeleteOne(ctx, filter, opts) + res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).DeleteOne(ctx, filter, opts) if err != nil { return nil, err } @@ -159,14 +152,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbfind/mongodbfind.go b/internal/tools/mongodb/mongodbfind/mongodbfind.go index fb67d7fb1f..88f3b25488 100644 --- a/internal/tools/mongodb/mongodbfind/mongodbfind.go +++ b/internal/tools/mongodb/mongodbfind/mongodbfind.go @@ -20,7 +20,6 @@ import ( "slices" "github.com/goccy/go-yaml" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" @@ -47,6 +46,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -72,18 +75,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.FilterParams, cfg.ProjectParams, cfg.SortParams) @@ -111,7 +102,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -122,9 +112,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - database *mongo.Database + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -169,6 +157,11 @@ func getOptions(ctx context.Context, sortParameters parameters.Parameters, proje } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBFindFilterString", t.FilterPayload, paramsMap) @@ -188,7 +181,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, err } - cur, err := t.database.Collection(t.Collection).Find(ctx, filter, opts) + cur, err := source.MongoClient().Database(t.Database).Collection(t.Collection).Find(ctx, filter, opts) if err != nil { return nil, err } @@ -230,14 +223,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbfindone/mongodbfindone.go b/internal/tools/mongodb/mongodbfindone/mongodbfindone.go index 3d49e65377..2e01d8e644 100644 --- a/internal/tools/mongodb/mongodbfindone/mongodbfindone.go +++ b/internal/tools/mongodb/mongodbfindone/mongodbfindone.go @@ -20,7 +20,6 @@ import ( "slices" "github.com/goccy/go-yaml" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -46,6 +45,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -68,18 +71,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.FilterParams, cfg.ProjectParams) @@ -103,7 +94,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -114,14 +104,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - database *mongo.Database + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBFindOneFilterString", t.FilterPayload, paramsMap) @@ -150,7 +143,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, err } - res := t.database.Collection(t.Collection).FindOne(ctx, filter, opts) + res := source.MongoClient().Database(t.Database).Collection(t.Collection).FindOne(ctx, filter, opts) if res.Err() != nil { return nil, res.Err() } @@ -189,14 +182,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go b/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go index ea19e17901..f0cbf29d1d 100644 --- a/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go +++ b/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" @@ -46,6 +45,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -65,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - dataParam := parameters.NewStringParameterWithRequired(paramDataKey, "the JSON payload to insert, should be a JSON array of documents", true) allParameters := parameters.Parameters{dataParam} @@ -94,7 +85,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, PayloadParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -106,31 +96,34 @@ var _ tools.Tool = Tool{} type Tool struct { Config PayloadParams parameters.Parameters - - database *mongo.Database - manifest tools.Manifest - mcpManifest tools.McpManifest + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + if len(params) == 0 { return nil, errors.New("no input found") } paramsMap := params.AsMap() - var jsonData, ok = paramsMap[paramDataKey].(string) + jsonData, ok := paramsMap[paramDataKey].(string) if !ok { return nil, errors.New("no input found") } var data = []any{} - err := bson.UnmarshalExtJSON([]byte(jsonData), t.Canonical, &data) + err = bson.UnmarshalExtJSON([]byte(jsonData), t.Canonical, &data) if err != nil { return nil, err } - res, err := t.database.Collection(t.Collection).InsertMany(ctx, data, options.InsertMany()) + res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).InsertMany(ctx, data, options.InsertMany()) if err != nil { return nil, err } @@ -154,14 +147,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go b/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go index 957dd47e7e..037a01dda7 100644 --- a/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go +++ b/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" @@ -46,6 +45,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -65,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - payloadParams := parameters.NewStringParameterWithRequired(dataParamsKey, "the JSON payload to insert, should be a JSON object", true) allParameters := parameters.Parameters{payloadParams} @@ -95,7 +86,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, PayloadParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -107,29 +97,32 @@ var _ tools.Tool = Tool{} type Tool struct { Config PayloadParams parameters.Parameters `yaml:"payloadParams" validate:"required"` - - database *mongo.Database - manifest tools.Manifest - mcpManifest tools.McpManifest + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + if len(params) == 0 { return nil, errors.New("no input found") } // use the first, assume it's a string - var jsonData, ok = params[0].Value.(string) + jsonData, ok := params[0].Value.(string) if !ok { return nil, errors.New("no input found") } var data any - err := bson.UnmarshalExtJSON([]byte(jsonData), t.Canonical, &data) + err = bson.UnmarshalExtJSON([]byte(jsonData), t.Canonical, &data) if err != nil { return nil, err } - res, err := t.database.Collection(t.Collection).InsertOne(ctx, data, options.InsertOne()) + res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).InsertOne(ctx, data, options.InsertOne()) if err != nil { return nil, err } @@ -153,14 +146,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go b/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go index 723e400e3e..1d38f1ff26 100644 --- a/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go +++ b/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" @@ -44,6 +43,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -68,18 +71,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.FilterParams, cfg.UpdateParams) @@ -103,7 +94,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -114,14 +104,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - database *mongo.Database + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateManyFilter", t.FilterPayload, paramsMap) @@ -146,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to unmarshal update string: %w", err) } - res, err := t.database.Collection(t.Collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(t.Upsert)) + res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(t.Upsert)) if err != nil { return nil, fmt.Errorf("error updating collection: %w", err) } @@ -170,14 +163,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go b/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go index c656353ae6..397b521198 100644 --- a/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go +++ b/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" @@ -44,6 +43,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,18 +72,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.FilterParams, cfg.UpdateParams) @@ -104,7 +95,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -115,14 +105,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters - - database *mongo.Database + AllParams parameters.Parameters manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateOneFilter", t.FilterPayload, paramsMap) @@ -147,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to unmarshal update string: %w", err) } - res, err := t.database.Collection(t.Collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(t.Upsert)) + res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(t.Upsert)) if err != nil { return nil, fmt.Errorf("error updating collection: %w", err) } @@ -171,14 +164,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go index e2bbbb4cc2..ddfbdb089e 100644 --- a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go +++ b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go @@ -21,8 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" - "github.com/googleapis/genai-toolbox/internal/sources/mssql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/orderedmap" @@ -49,12 +47,6 @@ type compatibleSource interface { MSSQLDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &cloudsqlmssql.Source{} -var _ compatibleSource = &mssql.Source{} - -var compatibleSources = [...]string{cloudsqlmssql.SourceKind, mssql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -92,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.MSSQLDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -104,14 +83,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -125,7 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - results, err := t.Pool.QueryContext(ctx, sql) + results, err := source.MSSQLDB().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -183,14 +165,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mssql/mssqllisttables/mssqllisttables.go b/internal/tools/mssql/mssqllisttables/mssqllisttables.go index 03341132e2..29fbea4498 100644 --- a/internal/tools/mssql/mssqllisttables/mssqllisttables.go +++ b/internal/tools/mssql/mssqllisttables/mssqllisttables.go @@ -21,8 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" - "github.com/googleapis/genai-toolbox/internal/sources/mssql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -296,12 +294,6 @@ type compatibleSource interface { MSSQLDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &cloudsqlmssql.Source{} -var _ compatibleSource = &mssql.Source{} - -var compatibleSources = [...]string{cloudsqlmssql.SourceKind, mssql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -318,18 +310,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("table_names", "", "Optional: A comma-separated list of table names. If empty, details for all tables will be listed."), parameters.NewStringParameterWithDefault("output_format", "detailed", "Optional: Use 'simple' for names only or 'detailed' for full info."), @@ -341,7 +321,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Db: s.MSSQLDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -353,14 +332,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Db *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() outputFormat, _ := paramsMap["output_format"].(string) @@ -373,7 +355,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sql.Named("output_format", outputFormat), } - rows, err := t.Db.QueryContext(ctx, listTablesStatement, namedArgs...) + rows, err := source.MSSQLDB().QueryContext(ctx, listTablesStatement, namedArgs...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -428,14 +410,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mssql/mssqlsql/mssqlsql.go b/internal/tools/mssql/mssqlsql/mssqlsql.go index 7b18fabbcc..0e621b7417 100644 --- a/internal/tools/mssql/mssqlsql/mssqlsql.go +++ b/internal/tools/mssql/mssqlsql/mssqlsql.go @@ -22,8 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" - "github.com/googleapis/genai-toolbox/internal/sources/mssql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -48,12 +46,6 @@ type compatibleSource interface { MSSQLDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &cloudsqlmssql.Source{} -var _ compatibleSource = &mssql.Source{} - -var compatibleSources = [...]string{cloudsqlmssql.SourceKind, mssql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -73,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -96,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Db: s.MSSQLDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -108,14 +87,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Db *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -140,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - rows, err := t.Db.QueryContext(ctx, newStatement, namedArgs...) + rows, err := source.MSSQLDB().QueryContext(ctx, newStatement, namedArgs...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -198,14 +180,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go index 0a780b621e..5198602d70 100644 --- a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go +++ b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go @@ -21,9 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" - "github.com/googleapis/genai-toolbox/internal/sources/mindsdb" - "github.com/googleapis/genai-toolbox/internal/sources/mysql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -51,13 +48,6 @@ type compatibleSource interface { MySQLPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &cloudsqlmysql.Source{} -var _ compatibleSource = &mysql.Source{} -var _ compatibleSource = &mindsdb.Source{} - -var compatibleSources = [...]string{cloudsqlmysql.SourceKind, mysql.SourceKind, mindsdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -74,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -95,7 +73,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.MySQLPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -107,14 +84,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -128,7 +108,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - results, err := t.Pool.QueryContext(ctx, sql) + results, err := source.MySQLPool().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -197,14 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(_ tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go new file mode 100644 index 0000000000..3458a6ed83 --- /dev/null +++ b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go @@ -0,0 +1,164 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysqlgetqueryplan + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +const kind string = "mysql-get-query-plan" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + MySQLPool() *sql.DB +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + sqlParameter := parameters.NewStringParameter("sql_statement", "The sql statement to explain.") + params := parameters.Parameters{sqlParameter} + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) + + // finish tool setup + t := Tool{ + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + sql, ok := paramsMap["sql_statement"].(string) + if !ok { + return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql_statement"]) + } + + // Log the query executed for debugging. + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("error getting logger: %s", err) + } + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) + + query := fmt.Sprintf("EXPLAIN FORMAT=JSON %s", sql) + results, err := source.MySQLPool().QueryContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + var plan string + if results.Next() { + if err := results.Scan(&plan); err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + } else { + return nil, fmt.Errorf("no query plan returned") + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + } + + var out any + if err := json.Unmarshal([]byte(plan), &out); err != nil { + return nil, fmt.Errorf("failed to unmarshal query plan json: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(_ tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) { + return "Authorization", nil +} diff --git a/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan_test.go b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan_test.go new file mode 100644 index 0000000000..b06248dbaf --- /dev/null +++ b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan_test.go @@ -0,0 +1,76 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysqlgetqueryplan_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan" +) + +func TestParseFromYamlGetQueryPlan(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: mysql-get-query-plan + source: my-instance + description: some description + authRequired: + - my-google-auth-service + - other-auth-service + `, + want: server.ToolConfigs{ + "example_tool": mysqlgetqueryplan.Config{ + Name: "example_tool", + Kind: "mysql-get-query-plan", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go b/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go index 0768a305b8..323d582d32 100644 --- a/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go +++ b/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go @@ -111,12 +111,6 @@ type compatibleSource interface { MySQLPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &mysql.Source{} -var _ compatibleSource = &cloudsqlmysql.Source{} - -var compatibleSources = [...]string{mysql.SourceKind, cloudsqlmysql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -138,11 +132,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - // verify the source is compatible - s, ok := rawS.(compatibleSource) + _, ok = rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } allParameters := parameters.Parameters{ @@ -165,7 +158,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ Config: cfg, - Pool: s.MySQLPool(), allParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -180,13 +172,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"parameters"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest statement string } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() duration, ok := paramsMap["min_duration_secs"].(int) @@ -205,7 +201,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, t.statement)) - results, err := t.Pool.QueryContext(ctx, t.statement, duration, duration, limit) + results, err := source.MySQLPool().QueryContext(ctx, t.statement, duration, duration, limit) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -273,14 +269,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go b/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go index d0346a1a68..a0bc1b8f66 100644 --- a/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go +++ b/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go @@ -21,8 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" - "github.com/googleapis/genai-toolbox/internal/sources/mysql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -71,12 +69,6 @@ type compatibleSource interface { MySQLPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &mysql.Source{} -var _ compatibleSource = &cloudsqlmysql.Source{} - -var compatibleSources = [...]string{mysql.SourceKind, cloudsqlmysql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -93,18 +85,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("table_schema", "", "(Optional) The database where fragmentation check is to be executed. Check all tables visible to the current user if not specified"), parameters.NewStringParameterWithDefault("table_name", "", "(Optional) Name of the table to be checked. Check all tables visible to the current user if not specified."), @@ -116,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ Config: cfg, - Pool: s.MySQLPool(), allParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -130,12 +109,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"parameters"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() table_schema, ok := paramsMap["table_schema"].(string) @@ -162,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, listTableFragmentationStatement)) - results, err := t.Pool.QueryContext(ctx, listTableFragmentationStatement, table_schema, table_schema, table_name, table_name, data_free_threshold_bytes, limit) + results, err := source.MySQLPool().QueryContext(ctx, listTableFragmentationStatement, table_schema, table_schema, table_name, table_name, data_free_threshold_bytes, limit) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -230,14 +213,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mysql/mysqllisttables/mysqllisttables.go b/internal/tools/mysql/mysqllisttables/mysqllisttables.go index ef4c9e6666..66928b75fa 100644 --- a/internal/tools/mysql/mysqllisttables/mysqllisttables.go +++ b/internal/tools/mysql/mysqllisttables/mysqllisttables.go @@ -21,8 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" - "github.com/googleapis/genai-toolbox/internal/sources/mysql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -201,12 +199,6 @@ type compatibleSource interface { MySQLPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &cloudsqlmysql.Source{} -var _ compatibleSource = &mysql.Source{} - -var compatibleSources = [...]string{cloudsqlmysql.SourceKind, mysql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -223,18 +215,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("table_names", "", "Optional: A comma-separated list of table names. If empty, details for all tables will be listed."), parameters.NewStringParameterWithDefault("output_format", "detailed", "Optional: Use 'simple' for names only or 'detailed' for full info."), @@ -246,7 +226,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.MySQLPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -260,12 +239,16 @@ type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() tableNames, ok := paramsMap["table_names"].(string) @@ -277,7 +260,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) } - results, err := t.Pool.QueryContext(ctx, listTablesStatement, tableNames, outputFormat) + results, err := source.MySQLPool().QueryContext(ctx, listTablesStatement, tableNames, outputFormat) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -345,14 +328,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go b/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go index 4931b66a5d..522b180acd 100644 --- a/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go +++ b/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go @@ -21,8 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" - "github.com/googleapis/genai-toolbox/internal/sources/mysql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -72,12 +70,6 @@ type compatibleSource interface { MySQLPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &mysql.Source{} -var _ compatibleSource = &cloudsqlmysql.Source{} - -var compatibleSources = [...]string{mysql.SourceKind, cloudsqlmysql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -94,18 +86,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("table_schema", "", "(Optional) The database where the check is to be performed. Check all tables visible to the current user if not specified"), parameters.NewIntParameterWithDefault("limit", 50, "(Optional) Max rows to return, default is 50"), @@ -115,7 +95,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ Config: cfg, - Pool: s.MySQLPool(), allParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -129,12 +108,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"parameters"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() table_schema, ok := paramsMap["table_schema"].(string) @@ -153,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, listTablesMissingUniqueIndexesStatement)) - results, err := t.Pool.QueryContext(ctx, listTablesMissingUniqueIndexesStatement, table_schema, table_schema, limit) + results, err := source.MySQLPool().QueryContext(ctx, listTablesMissingUniqueIndexesStatement, table_schema, table_schema, limit) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -221,14 +204,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mysql/mysqlsql/mysqlsql.go b/internal/tools/mysql/mysqlsql/mysqlsql.go index 4b3aed5a59..edf5f65db1 100644 --- a/internal/tools/mysql/mysqlsql/mysqlsql.go +++ b/internal/tools/mysql/mysqlsql/mysqlsql.go @@ -21,9 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" - "github.com/googleapis/genai-toolbox/internal/sources/mindsdb" - "github.com/googleapis/genai-toolbox/internal/sources/mysql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -49,13 +46,6 @@ type compatibleSource interface { MySQLPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &cloudsqlmysql.Source{} -var _ compatibleSource = &mysql.Source{} -var _ compatibleSource = &mindsdb.Source{} - -var compatibleSources = [...]string{cloudsqlmysql.SourceKind, mysql.SourceKind, mindsdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -75,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -98,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.MySQLPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -110,14 +87,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -130,7 +110,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.MySQLPool().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -198,14 +178,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/neo4j/neo4jcypher/neo4jcypher.go b/internal/tools/neo4j/neo4jcypher/neo4jcypher.go index 294ac5e90c..5f5c4ce05b 100644 --- a/internal/tools/neo4j/neo4jcypher/neo4jcypher.go +++ b/internal/tools/neo4j/neo4jcypher/neo4jcypher.go @@ -19,7 +19,6 @@ import ( "fmt" "github.com/goccy/go-yaml" - neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers" "github.com/neo4j/neo4j-go-driver/v5/neo4j" @@ -49,11 +48,6 @@ type compatibleSource interface { Neo4jDatabase() string } -// validate compatible sources are still compatible -var _ compatibleSource = &neo4jsc.Source{} - -var compatibleSources = [...]string{neo4jsc.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -72,25 +66,11 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, cfg.Parameters, nil) // finish tool setup t := Tool{ Config: cfg, - Driver: s.Neo4jDriver(), - Database: s.Neo4jDatabase(), manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -102,18 +82,20 @@ var _ tools.Tool = Tool{} type Tool struct { Config - - Driver neo4j.DriverWithContext - Database string manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() - config := neo4j.ExecuteQueryWithDatabase(t.Database) - results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, t.Driver, t.Statement, paramsMap, + config := neo4j.ExecuteQueryWithDatabase(source.Neo4jDatabase()) + results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, source.Neo4jDriver(), t.Statement, paramsMap, neo4j.EagerResultTransformer, config) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) @@ -149,14 +131,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go b/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go index 1f7fb8837e..0bf2b8f34e 100644 --- a/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go +++ b/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher/classifier" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers" @@ -49,11 +48,6 @@ type compatibleSource interface { Neo4jDatabase() string } -// validate compatible sources are still compatible -var _ compatibleSource = &neo4jsc.Source{} - -var compatibleSources = [...]string{neo4jsc.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,19 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - var s compatibleSource - s, ok = rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - cypherParameter := parameters.NewStringParameter("cypher", "The cypher to execute.") dryRunParameter := parameters.NewBooleanParameterWithDefault( "dry_run", @@ -99,8 +80,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Driver: s.Neo4jDriver(), - Database: s.Neo4jDatabase(), classifier: classifier.NewQueryClassifier(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -114,14 +93,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config Parameters parameters.Parameters `yaml:"parameters"` - Database string - Driver neo4j.DriverWithContext classifier *classifier.QueryClassifier manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() cypherStr, ok := paramsMap["cypher"].(string) if !ok { @@ -152,8 +134,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para cypherStr = "EXPLAIN " + cypherStr } - config := neo4j.ExecuteQueryWithDatabase(t.Database) - results, err := neo4j.ExecuteQuery(ctx, t.Driver, cypherStr, nil, + config := neo4j.ExecuteQueryWithDatabase(source.Neo4jDatabase()) + results, err := neo4j.ExecuteQuery(ctx, source.Neo4jDriver(), cypherStr, nil, neo4j.EagerResultTransformer, config) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) @@ -208,8 +190,8 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } // Recursive function to add plan children @@ -234,6 +216,6 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/neo4j/neo4jschema/neo4jschema.go b/internal/tools/neo4j/neo4jschema/neo4jschema.go index 6bef46a5e0..24b97cefb2 100644 --- a/internal/tools/neo4j/neo4jschema/neo4jschema.go +++ b/internal/tools/neo4j/neo4jschema/neo4jschema.go @@ -22,7 +22,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/cache" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers" @@ -58,12 +57,6 @@ type compatibleSource interface { Neo4jDatabase() string } -// Statically verify that our compatible source implementation is valid. -var _ compatibleSource = &neo4jsc.Source{} - -// compatibleSources lists the kinds of sources that are compatible with this tool. -var compatibleSources = [...]string{neo4jsc.SourceKind} - // Config holds the configuration settings for the Neo4j schema tool. // These settings are typically read from a YAML file. type Config struct { @@ -85,17 +78,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize sets up the tool with its dependencies and returns a ready-to-use Tool instance. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // Verify that the specified source exists. - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // Verify the source is of a compatible kind. - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) @@ -109,8 +91,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Finish tool setup by creating the Tool instance. t := Tool{ Config: cfg, - Driver: s.Neo4jDriver(), - Database: s.Neo4jDatabase(), cache: cache.NewCache(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -125,10 +105,7 @@ var _ tools.Tool = Tool{} // It holds the Neo4j driver, database information, and a cache for the schema. type Tool struct { Config - Driver neo4j.DriverWithContext - Database string - cache *cache.Cache - + cache *cache.Cache manifest tools.Manifest mcpManifest tools.McpManifest } @@ -136,6 +113,11 @@ type Tool struct { // Invoke executes the tool's main logic: fetching the Neo4j schema. // It first checks the cache for a valid schema before extracting it from the database. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + // Check if a valid schema is already in the cache. if cachedSchema, ok := t.cache.Get("schema"); ok { if schema, ok := cachedSchema.(*types.SchemaInfo); ok { @@ -144,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // If not cached, extract the schema from the database. - schema, err := t.extractSchema(ctx) + schema, err := t.extractSchema(ctx, source) if err != nil { return nil, fmt.Errorf("failed to extract database schema: %w", err) } @@ -176,16 +158,16 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } // checkAPOCProcedures verifies if essential APOC procedures are available in the database. // It returns true only if all required procedures are found. -func (t Tool) checkAPOCProcedures(ctx context.Context) (bool, error) { +func (t Tool) checkAPOCProcedures(ctx context.Context, source compatibleSource) (bool, error) { proceduresToCheck := []string{"apoc.meta.schema", "apoc.meta.cypher.types"} - session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) + session := source.Neo4jDriver().NewSession(ctx, neo4j.SessionConfig{DatabaseName: source.Neo4jDatabase()}) defer session.Close(ctx) // This query efficiently counts how many of the specified procedures exist. @@ -218,7 +200,7 @@ func (t Tool) checkAPOCProcedures(ctx context.Context) (bool, error) { // extractSchema orchestrates the concurrent extraction of different parts of the database schema. // It runs several extraction tasks in parallel for efficiency. -func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { +func (t Tool) extractSchema(ctx context.Context, source compatibleSource) (*types.SchemaInfo, error) { schema := &types.SchemaInfo{} var mu sync.Mutex @@ -230,7 +212,7 @@ func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { { name: "database-info", fn: func() error { - dbInfo, err := t.extractDatabaseInfo(ctx) + dbInfo, err := t.extractDatabaseInfo(ctx, source) if err != nil { return fmt.Errorf("failed to extract database info: %w", err) } @@ -244,7 +226,7 @@ func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { name: "schema-extraction", fn: func() error { // Check if APOC procedures are available. - hasAPOC, err := t.checkAPOCProcedures(ctx) + hasAPOC, err := t.checkAPOCProcedures(ctx, source) if err != nil { return fmt.Errorf("failed to check APOC procedures: %w", err) } @@ -255,9 +237,9 @@ func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { // Use APOC if available for a more detailed schema; otherwise, use native queries. if hasAPOC { - nodeLabels, relationships, stats, err = t.GetAPOCSchema(ctx) + nodeLabels, relationships, stats, err = t.GetAPOCSchema(ctx, source) } else { - nodeLabels, relationships, stats, err = t.GetSchemaWithoutAPOC(ctx, 100) + nodeLabels, relationships, stats, err = t.GetSchemaWithoutAPOC(ctx, source, 100) } if err != nil { return fmt.Errorf("failed to get schema: %w", err) @@ -274,7 +256,7 @@ func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { { name: "constraints", fn: func() error { - constraints, err := t.extractConstraints(ctx) + constraints, err := t.extractConstraints(ctx, source) if err != nil { return fmt.Errorf("failed to extract constraints: %w", err) } @@ -287,7 +269,7 @@ func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { { name: "indexes", fn: func() error { - indexes, err := t.extractIndexes(ctx) + indexes, err := t.extractIndexes(ctx, source) if err != nil { return fmt.Errorf("failed to extract indexes: %w", err) } @@ -329,7 +311,7 @@ func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { } // GetAPOCSchema extracts schema information using the APOC library, which provides detailed metadata. -func (t Tool) GetAPOCSchema(ctx context.Context) ([]types.NodeLabel, []types.Relationship, *types.Statistics, error) { +func (t Tool) GetAPOCSchema(ctx context.Context, source compatibleSource) ([]types.NodeLabel, []types.Relationship, *types.Statistics, error) { var nodeLabels []types.NodeLabel var relationships []types.Relationship stats := &types.Statistics{ @@ -444,7 +426,7 @@ func (t Tool) GetAPOCSchema(ctx context.Context) ([]types.NodeLabel, []types.Rel fn func(session neo4j.SessionWithContext) error }) { defer wg.Done() - session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) + session := source.Neo4jDriver().NewSession(ctx, neo4j.SessionConfig{DatabaseName: source.Neo4jDatabase()}) defer session.Close(ctx) if err := task.fn(session); err != nil { handleError(fmt.Errorf("task %s failed: %w", task.name, err)) @@ -461,7 +443,7 @@ func (t Tool) GetAPOCSchema(ctx context.Context) ([]types.NodeLabel, []types.Rel // GetSchemaWithoutAPOC extracts schema information using native Cypher queries. // This serves as a fallback for databases without APOC installed. -func (t Tool) GetSchemaWithoutAPOC(ctx context.Context, sampleSize int) ([]types.NodeLabel, []types.Relationship, *types.Statistics, error) { +func (t Tool) GetSchemaWithoutAPOC(ctx context.Context, source compatibleSource, sampleSize int) ([]types.NodeLabel, []types.Relationship, *types.Statistics, error) { nodePropsMap := make(map[string]map[string]map[string]bool) relPropsMap := make(map[string]map[string]map[string]bool) nodeCounts := make(map[string]int64) @@ -609,7 +591,7 @@ func (t Tool) GetSchemaWithoutAPOC(ctx context.Context, sampleSize int) ([]types fn func(session neo4j.SessionWithContext) error }) { defer wg.Done() - session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) + session := source.Neo4jDriver().NewSession(ctx, neo4j.SessionConfig{DatabaseName: source.Neo4jDatabase()}) defer session.Close(ctx) if err := task.fn(session); err != nil { handleError(fmt.Errorf("task %s failed: %w", task.name, err)) @@ -627,8 +609,8 @@ func (t Tool) GetSchemaWithoutAPOC(ctx context.Context, sampleSize int) ([]types } // extractDatabaseInfo retrieves general information about the Neo4j database instance. -func (t Tool) extractDatabaseInfo(ctx context.Context) (*types.DatabaseInfo, error) { - session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) +func (t Tool) extractDatabaseInfo(ctx context.Context, source compatibleSource) (*types.DatabaseInfo, error) { + session := source.Neo4jDriver().NewSession(ctx, neo4j.SessionConfig{DatabaseName: source.Neo4jDatabase()}) defer session.Close(ctx) result, err := session.Run(ctx, "CALL dbms.components() YIELD name, versions, edition", nil) @@ -649,8 +631,8 @@ func (t Tool) extractDatabaseInfo(ctx context.Context) (*types.DatabaseInfo, err } // extractConstraints fetches all schema constraints from the database. -func (t Tool) extractConstraints(ctx context.Context) ([]types.Constraint, error) { - session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) +func (t Tool) extractConstraints(ctx context.Context, source compatibleSource) ([]types.Constraint, error) { + session := source.Neo4jDriver().NewSession(ctx, neo4j.SessionConfig{DatabaseName: source.Neo4jDatabase()}) defer session.Close(ctx) result, err := session.Run(ctx, "SHOW CONSTRAINTS", nil) @@ -678,8 +660,8 @@ func (t Tool) extractConstraints(ctx context.Context) ([]types.Constraint, error } // extractIndexes fetches all schema indexes from the database. -func (t Tool) extractIndexes(ctx context.Context) ([]types.Index, error) { - session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) +func (t Tool) extractIndexes(ctx context.Context, source compatibleSource) ([]types.Index, error) { + session := source.Neo4jDriver().NewSession(ctx, neo4j.SessionConfig{DatabaseName: source.Neo4jDatabase()}) defer session.Close(ctx) result, err := session.Run(ctx, "SHOW INDEXES", nil) @@ -711,6 +693,6 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go b/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go index 3dea5ec1ca..fa8d7a96a9 100644 --- a/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go +++ b/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/oceanbase" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -39,11 +38,6 @@ type compatibleSource interface { OceanBasePool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &oceanbase.Source{} - -var compatibleSources = [...]string{oceanbase.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -68,18 +62,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -89,7 +71,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.OceanBasePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -101,22 +82,25 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } // Invoke executes the SQL statement provided in the parameters. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + sliceParams := params.AsSlice() sqlStr, ok := sliceParams[0].(string) if !ok { return nil, fmt.Errorf("unable to get cast %s", sliceParams[0]) } - results, err := t.Pool.QueryContext(ctx, sqlStr) + results, err := source.OceanBasePool().QueryContext(ctx, sqlStr) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -189,14 +173,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go b/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go index c411f80c51..10a4dc17de 100644 --- a/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go +++ b/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/oceanbase" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -39,11 +38,6 @@ type compatibleSource interface { OceanBasePool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &oceanbase.Source{} - -var compatibleSources = [...]string{oceanbase.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, fmt.Errorf("unable to process parameters: %w", err) @@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.OceanBasePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -106,15 +87,18 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } // Invoke executes the SQL statement with the provided parameters. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -127,7 +111,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.OceanBasePool().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -200,14 +184,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go b/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go index 23d3a9b3de..447f9362e9 100644 --- a/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go +++ b/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go @@ -11,7 +11,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/oracle" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -37,11 +36,6 @@ type compatibleSource interface { OracleDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &oracle.Source{} - -var compatibleSources = [...]string{oracle.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -58,18 +52,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The SQL to execute.") params := parameters.Parameters{sqlParameter} @@ -79,7 +61,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.OracleDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -91,14 +72,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sqlParam, ok := paramsMap["sql"].(string) if !ok { @@ -110,9 +94,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("error getting logger: %s", err) } - logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sqlParam)) + logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sqlParam) - results, err := t.Pool.QueryContext(ctx, sqlParam) + results, err := source.OracleDB().QueryContext(ctx, sqlParam) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -230,14 +214,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/oracle/oracleexecutesql/oracleexecutesql_test.go b/internal/tools/oracle/oracleexecutesql/oracleexecutesql_test.go new file mode 100644 index 0000000000..834d3d6981 --- /dev/null +++ b/internal/tools/oracle/oracleexecutesql/oracleexecutesql_test.go @@ -0,0 +1,82 @@ +// Copyright © 2025, Oracle and/or its affiliates. + +package oracleexecutesql_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/oracle/oracleexecutesql" +) + +func TestParseFromYamlOracleExecuteSql(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example with auth", + in: ` + tools: + run_adhoc_query: + kind: oracle-execute-sql + source: my-oracle-instance + description: Executes arbitrary SQL statements like INSERT or UPDATE. + authRequired: + - my-google-auth-service + `, + want: server.ToolConfigs{ + "run_adhoc_query": oracleexecutesql.Config{ + Name: "run_adhoc_query", + Kind: "oracle-execute-sql", + Source: "my-oracle-instance", + Description: "Executes arbitrary SQL statements like INSERT or UPDATE.", + AuthRequired: []string{"my-google-auth-service"}, + }, + }, + }, + { + desc: "example without authRequired", + in: ` + tools: + run_simple_update: + kind: oracle-execute-sql + source: db-dev + description: Runs a simple update operation. + `, + want: server.ToolConfigs{ + "run_simple_update": oracleexecutesql.Config{ + Name: "run_simple_update", + Kind: "oracle-execute-sql", + Source: "db-dev", + Description: "Runs a simple update operation.", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/oracle/oraclesql/oraclesql.go b/internal/tools/oracle/oraclesql/oraclesql.go index ff0cc07402..1ba87b47bd 100644 --- a/internal/tools/oracle/oraclesql/oraclesql.go +++ b/internal/tools/oracle/oraclesql/oraclesql.go @@ -11,7 +11,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/oracle" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -36,11 +35,6 @@ type compatibleSource interface { OracleDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &oracle.Source{} - -var compatibleSources = [...]string{oracle.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -60,18 +54,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, fmt.Errorf("error processing parameters: %w", err) @@ -83,7 +65,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - DB: s.OracleDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -95,14 +76,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - DB *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -120,7 +104,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } fmt.Printf("\n") - rows, err := t.DB.QueryContext(ctx, newStatement, sliceParams...) + rows, err := source.OracleDB().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -230,14 +214,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/oracle/oraclesql/oraclesql_test.go b/internal/tools/oracle/oraclesql/oraclesql_test.go new file mode 100644 index 0000000000..2ba0a7321c --- /dev/null +++ b/internal/tools/oracle/oraclesql/oraclesql_test.go @@ -0,0 +1,85 @@ +// Copyright © 2025, Oracle and/or its affiliates. +package oraclesql_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql" +) + +func TestParseFromYamlOracleSql(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example with statement and auth", + in: ` + tools: + get_user_by_id: + kind: oracle-sql + source: my-oracle-instance + description: Retrieves user details by ID. + statement: "SELECT id, name, email FROM users WHERE id = :1" + authRequired: + - my-google-auth-service + `, + want: server.ToolConfigs{ + "get_user_by_id": oraclesql.Config{ + Name: "get_user_by_id", + Kind: "oracle-sql", + Source: "my-oracle-instance", + Description: "Retrieves user details by ID.", + Statement: "SELECT id, name, email FROM users WHERE id = :1", + AuthRequired: []string{"my-google-auth-service"}, + }, + }, + }, + { + desc: "example with parameters and template parameters", + in: ` + tools: + get_orders: + kind: oracle-sql + source: db-prod + description: Gets orders for a customer with optional filtering. + statement: "SELECT * FROM ${SCHEMA}.ORDERS WHERE customer_id = :customer_id AND status = :status" + `, + want: server.ToolConfigs{ + "get_orders": oraclesql.Config{ + Name: "get_orders", + Kind: "oracle-sql", + Source: "db-prod", + Description: "Gets orders for a customer with optional filtering.", + Statement: "SELECT * FROM ${SCHEMA}.ORDERS WHERE customer_id = :customer_id AND status = :status", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go b/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go index e7fde8a842..4e8a0a29ce 100644 --- a/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go +++ b/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -62,13 +59,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -85,18 +75,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{} if cfg.Description == "" { cfg.Description = "Fetches the current state of the PostgreSQL server, returning the version, whether it's a replica, uptime duration, maximum connection limit, number of current connections, number of active connections, and the percentage of connections in use." @@ -107,7 +85,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -123,7 +100,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } @@ -133,6 +109,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -141,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, databaseOverviewStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, databaseOverviewStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -186,10 +167,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go b/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go index 56a204d4a2..73afd2a6ee 100644 --- a/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go +++ b/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/orderedmap" @@ -50,13 +47,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -73,18 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -94,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.PostgresPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -106,14 +83,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *pgxpool.Pool + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -126,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - results, err := t.Pool.Query(ctx, sql) + results, err := source.PostgresPool().Query(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -170,14 +150,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go b/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go index 150d61f86b..f96654fbc6 100644 --- a/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go +++ b/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -67,13 +64,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -90,18 +80,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("schema_name", "public", "Optional: The schema name in which the table is present."), parameters.NewStringParameterWithRequired("table_name", "Required: The table name in which the column is present.", true), @@ -117,11 +95,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - name: cfg.Name, - kind: cfg.Kind, - authRequired: cfg.AuthRequired, - allParams: allParameters, - pool: s.PostgresPool(), + Config: cfg, + allParams: allParameters, manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -136,13 +111,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - name string `yaml:"name"` - kind string `yaml:"kind"` - authRequired []string `yaml:"authRequired"` - allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool - manifest tools.Manifest - mcpManifest tools.McpManifest + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -150,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -158,7 +134,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, getColumnCardinality, sliceParams...) + results, err := source.PostgresPool().Query(ctx, getColumnCardinality, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -199,13 +175,13 @@ func (t Tool) McpManifest() tools.McpManifest { } func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.authRequired, verifiedAuthServices) + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go b/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go index 9531041b9e..6ad5bff569 100644 --- a/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go +++ b/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -71,13 +68,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -94,18 +84,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("min_duration", "1 minute", "Optional: Only show queries running at least this long (e.g., '1 minute', '1 second', '2 seconds')."), parameters.NewStringParameterWithDefault("exclude_application_names", "", "Optional: A comma-separated list of application names to exclude from the query results. This is useful for filtering out queries from specific applications (e.g., 'psql', 'pgAdmin', 'DBeaver'). The match is case-sensitive. Whitespace around commas and names is automatically handled. If this parameter is omitted, no applications are excluded."), @@ -118,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -135,12 +112,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -149,7 +130,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listActiveQueriesStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listActiveQueriesStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -194,14 +175,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go b/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go index e74d1709a2..1440509cbb 100644 --- a/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go +++ b/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -58,13 +55,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -81,25 +71,12 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) // finish tool setup t := Tool{ Config: cfg, - Pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -115,13 +92,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - results, err := t.Pool.Query(ctx, listAvailableExtensionsQuery) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + results, err := source.PostgresPool().Query(ctx, listAvailableExtensionsQuery) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -165,14 +146,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go b/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go index 31edc08f11..27cc16c1ed 100644 --- a/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go +++ b/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -115,13 +112,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -138,18 +128,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("database_name", "", "Optional: A specific database name pattern to search for."), parameters.NewBooleanParameterWithDefault("include_templates", false, "Optional: Whether to include template databases in the results."), @@ -188,7 +166,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -204,12 +181,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -218,7 +199,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listDatabaseStats, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listDatabaseStats, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -263,14 +244,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go b/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go index 34274908ec..0f85a0e46c 100644 --- a/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go +++ b/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -94,13 +91,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -117,18 +107,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("schema_name", "", "Optional: a text to filter results by schema name. The input is used within a LIKE clause."), parameters.NewStringParameterWithDefault("table_name", "", "Optional: a text to filter results by table name. The input is used within a LIKE clause."), @@ -146,7 +124,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -162,7 +139,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } @@ -172,6 +148,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -180,7 +161,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listIndexesStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listIndexesStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -225,10 +206,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go b/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go index 2ea41d9204..effa306f46 100644 --- a/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go +++ b/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -69,13 +66,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -92,25 +82,12 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) // finish tool setup t := Tool{ Config: cfg, - Pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -126,13 +103,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - results, err := t.Pool.Query(ctx, listAvailableExtensionsQuery) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + results, err := source.PostgresPool().Query(ctx, listAvailableExtensionsQuery) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -176,14 +157,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go b/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go index 68f6d566fe..881962e2be 100644 --- a/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go +++ b/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -69,13 +66,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -92,18 +82,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{} paramManifest := allParameters.Manifest() @@ -115,11 +93,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - name: cfg.Name, - kind: cfg.Kind, - authRequired: cfg.AuthRequired, - allParams: allParameters, - pool: s.PostgresPool(), + Config: cfg, + allParams: allParameters, manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -134,13 +109,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - name string `yaml:"name"` - kind string `yaml:"kind"` - authRequired []string `yaml:"authRequired"` - allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool - manifest tools.Manifest - mcpManifest tools.McpManifest + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -148,6 +119,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -156,7 +132,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listLocks, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listLocks, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -193,13 +169,13 @@ func (t Tool) McpManifest() tools.McpManifest { } func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.authRequired, verifiedAuthServices) + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go b/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go index ae5b9ff2dd..05fccc3d6e 100644 --- a/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go +++ b/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -67,13 +64,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -90,18 +80,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("setting_name", "", "Optional: A specific configuration parameter name pattern to search for."), parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of rows to return."), @@ -116,7 +94,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -132,12 +109,15 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -146,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listPgSettingsStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listPgSettingsStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -191,14 +171,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go b/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go index c280cd5ff0..9b1d48fdea 100644 --- a/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go +++ b/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -78,13 +75,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -101,18 +91,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("table_names", "", "Optional: Filters by a comma-separated list of table names."), parameters.NewStringParameterWithDefault("publication_names", "", "Optional: Filters by a comma-separated list of publication names."), @@ -129,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -145,12 +122,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -159,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listPublicationTablesStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listPublicationTablesStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -203,14 +184,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go b/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go index 1544eccefb..e2a26e496b 100644 --- a/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go +++ b/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -68,13 +65,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -91,18 +81,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("database_name", "", "Optional: The database name to list query stats for."), parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of results to return. Defaults to 50."), @@ -117,11 +95,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - name: cfg.Name, - kind: cfg.Kind, - authRequired: cfg.AuthRequired, - allParams: allParameters, - pool: s.PostgresPool(), + Config: cfg, + allParams: allParameters, manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -136,13 +111,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - name string `yaml:"name"` - kind string `yaml:"kind"` - authRequired []string `yaml:"authRequired"` - allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool - manifest tools.Manifest - mcpManifest tools.McpManifest + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -150,6 +121,10 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -158,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listQueryStats, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listQueryStats, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -199,13 +174,13 @@ func (t Tool) McpManifest() tools.McpManifest { } func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.authRequired, verifiedAuthServices) + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistroles/postgreslistroles.go b/internal/tools/postgres/postgreslistroles/postgreslistroles.go index 3e0f59dd32..160aebb31a 100644 --- a/internal/tools/postgres/postgreslistroles/postgreslistroles.go +++ b/internal/tools/postgres/postgreslistroles/postgreslistroles.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -90,13 +87,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -113,18 +103,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("role_name", "", "Optional: a text to filter results by role name. The input is used within a LIKE clause."), parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of rows to return. Default is 10"), @@ -140,7 +118,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: description, Parameters: allParameters.Manifest(), @@ -156,7 +133,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } @@ -166,6 +142,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -174,7 +155,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listRolesStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listRolesStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -219,10 +200,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go b/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go index 66ee3b2596..729a4af1b4 100644 --- a/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go +++ b/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -102,13 +99,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -125,18 +115,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("schema_name", "", "Optional: A specific schema name pattern to search for."), parameters.NewStringParameterWithDefault("owner", "", "Optional: A specific schema owner name pattern to search for."), @@ -152,7 +130,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -168,12 +145,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -182,7 +163,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listSchemasStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listSchemasStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -227,14 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go b/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go index ba2d3b53c9..a8877ab6f7 100644 --- a/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go +++ b/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -68,13 +65,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -91,18 +81,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("schema_name", "", "Optional: A specific schema name pattern to search for."), parameters.NewStringParameterWithDefault("sequence_name", "", "Optional: A specific sequence name pattern to search for."), @@ -118,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -134,7 +111,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } @@ -144,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -152,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listSequencesStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listSequencesStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -197,10 +178,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslisttables/postgreslisttables.go b/internal/tools/postgres/postgreslisttables/postgreslisttables.go index 5e949a755e..264983edb6 100644 --- a/internal/tools/postgres/postgreslisttables/postgreslisttables.go +++ b/internal/tools/postgres/postgreslisttables/postgreslisttables.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -126,13 +123,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -149,18 +139,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("table_names", "", "Optional: A comma-separated list of table names. If empty, details for all tables will be listed."), parameters.NewStringParameterWithDefault("output_format", "detailed", "Optional: Use 'simple' for names only or 'detailed' for full info."), @@ -171,7 +149,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.PostgresPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -184,14 +161,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *pgxpool.Pool + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() tableNames, ok := paramsMap["table_names"].(string) @@ -203,7 +183,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) } - results, err := t.Pool.Query(ctx, listTablesStatement, tableNames, outputFormat) + results, err := source.PostgresPool().Query(ctx, listTablesStatement, tableNames, outputFormat) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -247,14 +227,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go b/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go index b150f1ebf2..8e2d0e700d 100644 --- a/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go +++ b/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -74,13 +71,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -97,18 +87,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("tablespace_name", "", "Optional: a text to filter results by tablespace name. The input is used within a LIKE clause."), parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of rows to return."), @@ -123,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -139,7 +116,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } @@ -149,6 +125,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() tablespaceName, ok := paramsMap["tablespace_name"].(string) @@ -160,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid 'limit' parameter; expected an integer") } - results, err := t.pool.Query(ctx, listTableSpacesStatement, tablespaceName, limit) + results, err := source.PostgresPool().Query(ctx, listTableSpacesStatement, tablespaceName, limit) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -204,10 +185,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go b/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go index e5700a2629..69a953e654 100644 --- a/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go +++ b/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -95,13 +92,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -118,18 +108,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("schema_name", "public", "Optional: A specific schema name to filter by"), parameters.NewStringParameterWithRequired("table_name", "Optional: A specific table name to filter by", false), @@ -155,11 +133,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - name: cfg.Name, - kind: cfg.Kind, - authRequired: cfg.AuthRequired, - allParams: allParameters, - pool: s.PostgresPool(), + Config: cfg, + allParams: allParameters, manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -174,13 +149,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - name string `yaml:"name"` - kind string `yaml:"kind"` - authRequired []string `yaml:"authRequired"` - allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool - manifest tools.Manifest - mcpManifest tools.McpManifest + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -188,6 +159,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -196,7 +172,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listTableStats, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listTableStats, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -233,13 +209,13 @@ func (t Tool) McpManifest() tools.McpManifest { } func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.authRequired, verifiedAuthServices) + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go b/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go index 8fd3f6ed17..8fc4944f73 100644 --- a/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go +++ b/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -94,13 +91,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -117,18 +107,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("trigger_name", "", "Optional: A specific trigger name pattern to search for."), parameters.NewStringParameterWithDefault("schema_name", "", "Optional: A specific schema name pattern to search for."), @@ -145,7 +123,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -161,7 +138,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } @@ -171,6 +147,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -179,7 +160,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listTriggersStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listTriggersStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -224,10 +205,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistviews/postgreslistviews.go b/internal/tools/postgres/postgreslistviews/postgreslistviews.go index ed2e7306dd..d0aa2438d1 100644 --- a/internal/tools/postgres/postgreslistviews/postgreslistviews.go +++ b/internal/tools/postgres/postgreslistviews/postgreslistviews.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -69,13 +66,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -92,18 +82,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("view_name", "", "Optional: A specific view name to search for."), parameters.NewStringParameterWithDefault("schema_name", "", "Optional: A specific schema name to search for."), @@ -119,7 +97,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -135,12 +112,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -149,7 +130,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listViewsStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listViewsStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -194,14 +175,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go b/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go index 1286bd57f6..1b2434679d 100644 --- a/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go +++ b/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -76,13 +73,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -99,18 +89,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("min_duration", "5 minutes", "Optional: Only show transactions running at least this long (e.g., '1 minute', '15 minutes', '30 seconds')."), parameters.NewIntParameterWithDefault("limit", 20, "Optional: The maximum number of long-running transactions to return. Defaults to 20."), @@ -125,11 +103,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - name: cfg.Name, - kind: cfg.Kind, - authRequired: cfg.AuthRequired, - allParams: allParameters, - pool: s.PostgresPool(), + Config: cfg, + allParams: allParameters, manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -144,13 +119,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - name string `yaml:"name"` - kind string `yaml:"kind"` - authRequired []string `yaml:"authRequired"` - allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool - manifest tools.Manifest - mcpManifest tools.McpManifest + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -158,6 +129,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -166,7 +142,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, longRunningTransactions, sliceParams...) + results, err := source.PostgresPool().Query(ctx, longRunningTransactions, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -203,13 +179,13 @@ func (t Tool) McpManifest() tools.McpManifest { } func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.authRequired, verifiedAuthServices) + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go b/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go index 2ef3e7fe3e..4280f1a0a3 100644 --- a/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go +++ b/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -66,13 +63,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -89,18 +79,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{} paramManifest := allParameters.Manifest() @@ -112,11 +90,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - name: cfg.Name, - kind: cfg.Kind, - authRequired: cfg.AuthRequired, - allParams: allParameters, - pool: s.PostgresPool(), + Config: cfg, + allParams: allParameters, manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -131,13 +106,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - name string `yaml:"name"` - kind string `yaml:"kind"` - authRequired []string `yaml:"authRequired"` - allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool - manifest tools.Manifest - mcpManifest tools.McpManifest + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -145,6 +116,11 @@ func (t Tool) ToConfig() tools.ToolConfig { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -153,7 +129,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, replicationStats, sliceParams...) + results, err := source.PostgresPool().Query(ctx, replicationStats, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -190,13 +166,13 @@ func (t Tool) McpManifest() tools.McpManifest { } func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.authRequired, verifiedAuthServices) + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgressql/postgressql.go b/internal/tools/postgres/postgressql/postgressql.go index 5e8d871372..1de22a5a82 100644 --- a/internal/tools/postgres/postgressql/postgressql.go +++ b/internal/tools/postgres/postgressql/postgressql.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -48,13 +45,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -74,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -97,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.PostgresPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -109,14 +86,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *pgxpool.Pool + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -128,7 +108,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - results, err := t.Pool.Query(ctx, newStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -172,14 +152,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/redis/redis.go b/internal/tools/redis/redis.go index a530eec167..6995163a6a 100644 --- a/internal/tools/redis/redis.go +++ b/internal/tools/redis/redis.go @@ -46,11 +46,6 @@ type compatibleSource interface { RedisClient() redissrc.RedisClient } -// validate compatible sources are still compatible -var _ compatibleSource = &redissrc.Source{} - -var compatibleSources = [...]string{redissrc.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,24 +64,11 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, cfg.Parameters, nil) // finish tool setup t := Tool{ Config: cfg, - Client: s.RedisClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -98,13 +80,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config - - Client redissrc.RedisClient manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + cmds, err := replaceCommandsParams(t.Commands, t.Parameters, params) if err != nil { return nil, fmt.Errorf("error replacing commands' parameters: %s", err) @@ -113,7 +98,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Execute commands responses := make([]*redis.Cmd, len(cmds)) for i, cmd := range cmds { - responses[i] = t.Client.Do(ctx, cmd...) + responses[i] = source.RedisClient().Do(ctx, cmd...) } // Parse responses out := make([]any, len(t.Commands)) @@ -165,8 +150,8 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } // replaceCommandsParams is a helper function to replace parameters in the commands @@ -207,6 +192,6 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/serverlessspark/createbatch/config.go b/internal/tools/serverlessspark/createbatch/config.go index 54370516f9..0bb3575a39 100644 --- a/internal/tools/serverlessspark/createbatch/config.go +++ b/internal/tools/serverlessspark/createbatch/config.go @@ -19,7 +19,8 @@ import ( "encoding/json" "fmt" - dataproc "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" + dataproc "cloud.google.com/go/dataproc/v2/apiv1" + dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" "github.com/goccy/go-yaml" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" @@ -34,17 +35,23 @@ func unmarshalProto(data any, m proto.Message) error { return protojson.Unmarshal(jsonData, m) } +type compatibleSource interface { + GetBatchControllerClient() *dataproc.BatchControllerClient + GetProject() string + GetLocation() string +} + // Config is a common config that can be used with any type of create batch tool. However, each tool // will still need its own config type, embedding this Config, so it can provide a type-specific // Initialize implementation. type Config struct { - Name string `yaml:"name" validate:"required"` - Kind string `yaml:"kind" validate:"required"` - Source string `yaml:"source" validate:"required"` - Description string `yaml:"description"` - RuntimeConfig *dataproc.RuntimeConfig `yaml:"runtimeConfig"` - EnvironmentConfig *dataproc.EnvironmentConfig `yaml:"environmentConfig"` - AuthRequired []string `yaml:"authRequired"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + RuntimeConfig *dataprocpb.RuntimeConfig `yaml:"runtimeConfig"` + EnvironmentConfig *dataprocpb.EnvironmentConfig `yaml:"environmentConfig"` + AuthRequired []string `yaml:"authRequired"` } func NewConfig(ctx context.Context, name string, decoder *yaml.Decoder) (Config, error) { @@ -73,7 +80,7 @@ func NewConfig(ctx context.Context, name string, decoder *yaml.Decoder) (Config, } if ymlCfg.RuntimeConfig != nil { - rc := &dataproc.RuntimeConfig{} + rc := &dataprocpb.RuntimeConfig{} if err := unmarshalProto(ymlCfg.RuntimeConfig, rc); err != nil { return Config{}, fmt.Errorf("failed to unmarshal runtimeConfig: %w", err) } @@ -81,7 +88,7 @@ func NewConfig(ctx context.Context, name string, decoder *yaml.Decoder) (Config, } if ymlCfg.EnvironmentConfig != nil { - ec := &dataproc.EnvironmentConfig{} + ec := &dataprocpb.EnvironmentConfig{} if err := unmarshalProto(ymlCfg.EnvironmentConfig, ec); err != nil { return Config{}, fmt.Errorf("failed to unmarshal environmentConfig: %w", err) } diff --git a/internal/tools/serverlessspark/createbatch/tool.go b/internal/tools/serverlessspark/createbatch/tool.go index a4e45ea64f..66702533da 100644 --- a/internal/tools/serverlessspark/createbatch/tool.go +++ b/internal/tools/serverlessspark/createbatch/tool.go @@ -20,9 +20,8 @@ import ( "fmt" "time" - dataproc "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" + dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -32,20 +31,10 @@ import ( type BatchBuilder interface { Parameters() parameters.Parameters - BuildBatch(params parameters.ParamValues) (*dataproc.Batch, error) + BuildBatch(params parameters.ParamValues) (*dataprocpb.Batch, error) } func NewTool(cfg Config, originalCfg tools.ToolConfig, srcs map[string]sources.Source, builder BatchBuilder) (*Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("source %q not found", cfg.Source) - } - - ds, ok := rawS.(*serverlessspark.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", cfg.Kind, serverlessspark.SourceKind) - } - desc := cfg.Description if desc == "" { desc = fmt.Sprintf("Creates a Serverless Spark (aka Dataproc Serverless) %s operation.", cfg.Kind) @@ -63,7 +52,6 @@ func NewTool(cfg Config, originalCfg tools.ToolConfig, srcs map[string]sources.S return &Tool{ Config: cfg, originalConfig: originalCfg, - Source: ds, Builder: builder, manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()}, mcpManifest: mcpManifest, @@ -74,17 +62,18 @@ func NewTool(cfg Config, originalCfg tools.ToolConfig, srcs map[string]sources.S type Tool struct { Config originalConfig tools.ToolConfig - - Source *serverlessspark.Source - Builder BatchBuilder - - manifest tools.Manifest - mcpManifest tools.McpManifest - Parameters parameters.Parameters + Builder BatchBuilder + manifest tools.Manifest + mcpManifest tools.McpManifest + Parameters parameters.Parameters } func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - client := t.Source.GetBatchControllerClient() + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + client := source.GetBatchControllerClient() batch, err := t.Builder.BuildBatch(params) if err != nil { @@ -92,24 +81,24 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par } if t.RuntimeConfig != nil { - batch.RuntimeConfig = proto.Clone(t.RuntimeConfig).(*dataproc.RuntimeConfig) + batch.RuntimeConfig = proto.Clone(t.RuntimeConfig).(*dataprocpb.RuntimeConfig) } if t.EnvironmentConfig != nil { - batch.EnvironmentConfig = proto.Clone(t.EnvironmentConfig).(*dataproc.EnvironmentConfig) + batch.EnvironmentConfig = proto.Clone(t.EnvironmentConfig).(*dataprocpb.EnvironmentConfig) } // Common override for version if present in params paramMap := params.AsMap() if version, ok := paramMap["version"].(string); ok && version != "" { if batch.RuntimeConfig == nil { - batch.RuntimeConfig = &dataproc.RuntimeConfig{} + batch.RuntimeConfig = &dataprocpb.RuntimeConfig{} } batch.RuntimeConfig.Version = version } - req := &dataproc.CreateBatchRequest{ - Parent: fmt.Sprintf("projects/%s/locations/%s", t.Source.Project, t.Source.Location), + req := &dataprocpb.CreateBatchRequest{ + Parent: fmt.Sprintf("projects/%s/locations/%s", source.GetProject(), source.GetLocation()), Batch: batch, } @@ -165,14 +154,14 @@ func (t *Tool) Authorized(services []string) bool { return tools.IsAuthorized(t.AuthRequired, services) } -func (t *Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t *Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t *Tool) ToConfig() tools.ToolConfig { return t.originalConfig } -func (t *Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t *Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go index f0240f5ebd..913a8151e6 100644 --- a/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go +++ b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go @@ -19,10 +19,10 @@ import ( "fmt" "strings" + longrunning "cloud.google.com/go/longrunning/autogen" "cloud.google.com/go/longrunning/autogen/longrunningpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -43,6 +43,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetOperationsClient(context.Context) (*longrunning.OperationsClient, error) + GetProject() string + GetLocation() string +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,16 +67,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize creates a new Tool instance. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("source %q not found", cfg.Source) - } - - ds, ok := rawS.(*serverlessspark.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind) - } - desc := cfg.Description if desc == "" { desc = "Cancels a running Serverless Spark (aka Dataproc Serverless) batch operation. Note that the batch state will not change immediately after the tool returns; it can take a minute or so for the cancellation to be reflected." @@ -89,7 +85,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return &Tool{ Config: cfg, - Source: ds, manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()}, mcpManifest: mcpManifest, Parameters: allParameters, @@ -99,9 +94,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool is the implementation of the tool. type Tool struct { Config - - Source *serverlessspark.Source - manifest tools.Manifest mcpManifest tools.McpManifest Parameters parameters.Parameters @@ -109,7 +101,12 @@ type Tool struct { // Invoke executes the tool's operation. func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - client, err := t.Source.GetOperationsClient(ctx) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + client, err := source.GetOperationsClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get operations client: %w", err) } @@ -125,7 +122,7 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par } req := &longrunningpb.CancelOperationRequest{ - Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", t.Source.Project, t.Source.Location, operation), + Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", source.GetProject(), source.GetLocation(), operation), } err = client.CancelOperation(ctx, req) @@ -152,15 +149,15 @@ func (t *Tool) Authorized(services []string) bool { return tools.IsAuthorized(t.AuthRequired, services) } -func (t *Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { +func (t *Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { // Client OAuth not supported, rely on ADCs. - return false + return false, nil } func (t *Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go b/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go index 558910cb9f..aebec7c9e4 100644 --- a/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go +++ b/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go @@ -20,10 +20,10 @@ import ( "fmt" "strings" + dataproc "cloud.google.com/go/dataproc/v2/apiv1" "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -46,6 +46,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetBatchControllerClient() *dataproc.BatchControllerClient + GetProject() string + GetLocation() string +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -64,16 +70,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize creates a new Tool instance. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("source %q not found", cfg.Source) - } - - ds, ok := rawS.(*serverlessspark.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind) - } - desc := cfg.Description if desc == "" { desc = "Gets a Serverless Spark (aka Dataproc Serverless) batch" @@ -92,7 +88,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: ds, manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()}, mcpManifest: mcpManifest, Parameters: allParameters, @@ -102,9 +97,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool is the implementation of the tool. type Tool struct { Config - - Source *serverlessspark.Source - manifest tools.Manifest mcpManifest tools.McpManifest Parameters parameters.Parameters @@ -112,7 +104,12 @@ type Tool struct { // Invoke executes the tool's operation. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - client := t.Source.GetBatchControllerClient() + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + client := source.GetBatchControllerClient() paramMap := params.AsMap() name, ok := paramMap["name"].(string) @@ -125,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } req := &dataprocpb.GetBatchRequest{ - Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", t.Source.Project, t.Source.Location, name), + Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", source.GetProject(), source.GetLocation(), name), } batchPb, err := client.GetBatch(ctx, req) @@ -176,15 +173,15 @@ func (t Tool) Authorized(services []string) bool { return tools.IsAuthorized(t.AuthRequired, services) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { // Client OAuth not supported, rely on ADCs. - return false + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go index bb206195ca..bc8bea2caa 100644 --- a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go +++ b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go @@ -19,10 +19,10 @@ import ( "fmt" "time" + dataproc "cloud.google.com/go/dataproc/v2/apiv1" "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -45,6 +45,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetBatchControllerClient() *dataproc.BatchControllerClient + GetProject() string + GetLocation() string +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -63,16 +69,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize creates a new Tool instance. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("source %q not found", cfg.Source) - } - - ds, ok := rawS.(*serverlessspark.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind) - } - desc := cfg.Description if desc == "" { desc = "Lists available Serverless Spark (aka Dataproc Serverless) batches" @@ -93,7 +89,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: ds, manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()}, mcpManifest: mcpManifest, Parameters: allParameters, @@ -103,9 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool is the implementation of the tool. type Tool struct { Config - - Source *serverlessspark.Source - manifest tools.Manifest mcpManifest tools.McpManifest Parameters parameters.Parameters @@ -131,9 +123,14 @@ type Batch struct { // Invoke executes the tool's operation. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - client := t.Source.GetBatchControllerClient() + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } - parent := fmt.Sprintf("projects/%s/locations/%s", t.Source.Project, t.Source.Location) + client := source.GetBatchControllerClient() + + parent := fmt.Sprintf("projects/%s/locations/%s", source.GetProject(), source.GetLocation()) req := &dataprocpb.ListBatchesRequest{ Parent: parent, OrderBy: "create_time desc", @@ -213,15 +210,15 @@ func (t Tool) Authorized(services []string) bool { return tools.IsAuthorized(t.AuthRequired, services) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { // Client OAuth not supported, rely on ADCs. - return false + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go b/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go index 64f3ac68cb..7ab352b195 100644 --- a/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go +++ b/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/singlestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -48,11 +47,6 @@ type compatibleSource interface { SingleStorePool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &singlestore.Source{} - -var compatibleSources = [...]string{singlestore.SourceKind} - // Config represents the configuration for the singlestore-execute-sql tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -72,18 +66,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize sets up the Tool using the provided sources map. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -93,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.SingleStorePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -107,7 +88,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config Parameters parameters.Parameters `yaml:"parameters"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest } @@ -118,6 +98,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // Invoke executes the provided SQL query using the tool's database connection and returns the results. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -131,7 +116,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql) - results, err := t.Pool.QueryContext(ctx, sql) + results, err := source.SingleStorePool().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -199,10 +184,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/singlestore/singlestoresql/singlestoresql.go b/internal/tools/singlestore/singlestoresql/singlestoresql.go index bdb3e9f8b6..55adfe2dbf 100644 --- a/internal/tools/singlestore/singlestoresql/singlestoresql.go +++ b/internal/tools/singlestore/singlestoresql/singlestoresql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/singlestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -47,11 +46,6 @@ type compatibleSource interface { SingleStorePool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &singlestore.Source{} - -var compatibleSources = [...]string{singlestore.SourceKind} - // Config defines the configuration for a SingleStore SQL tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -85,18 +79,6 @@ func (cfg Config) ToolConfigKind() string { // tools.Tool - the initialized tool instance. // error - an error if the source is missing, incompatible, or setup fails. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -108,7 +90,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.SingleStorePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -122,7 +103,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest } @@ -146,6 +126,11 @@ func (t Tool) ToConfig() tools.ToolConfig { // - A slice of maps, where each map represents a row with column names as keys. // - An error if template resolution, parameter extraction, query execution, or result processing fails. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -158,7 +143,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.SingleStorePool().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -226,10 +211,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go b/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go index 528e633fba..f0c4ce2460 100644 --- a/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go +++ b/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go @@ -21,7 +21,6 @@ import ( "cloud.google.com/go/spanner" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - spannerdb "github.com/googleapis/genai-toolbox/internal/sources/spanner" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/orderedmap" @@ -50,11 +49,6 @@ type compatibleSource interface { DatabaseDialect() string } -// validate compatible sources are still compatible -var _ compatibleSource = &spannerdb.Source{} - -var compatibleSources = [...]string{spannerdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -72,18 +66,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -93,8 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.SpannerClient(), - dialect: s.DatabaseDialect(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -107,8 +87,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config Parameters parameters.Parameters `yaml:"parameters"` - Client *spanner.Client - dialect string manifest tools.Manifest mcpManifest tools.McpManifest } @@ -138,6 +116,11 @@ func processRows(iter *spanner.RowIterator) ([]any, error) { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -156,10 +139,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para stmt := spanner.Statement{SQL: sql} if t.ReadOnly { - iter := t.Client.Single().Query(ctx, stmt) + iter := source.SpannerClient().Single().Query(ctx, stmt) results, opErr = processRows(iter) } else { - _, opErr = t.Client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + _, opErr = source.SpannerClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { var err error iter := txn.Query(ctx, stmt) results, err = processRows(iter) @@ -193,14 +176,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go b/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go index 2ecbe06214..b9e94408e2 100644 --- a/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go +++ b/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go @@ -23,7 +23,6 @@ import ( "cloud.google.com/go/spanner" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - spannerdb "github.com/googleapis/genai-toolbox/internal/sources/spanner" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/iterator" @@ -50,11 +49,6 @@ type compatibleSource interface { DatabaseDialect() string } -// validate compatible sources are still compatible -var _ compatibleSource = &spannerdb.Source{} - -var compatibleSources = [...]string{spannerdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Define parameters for the tool allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault( @@ -107,8 +89,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Client: s.SpannerClient(), - dialect: s.DatabaseDialect(), manifest: tools.Manifest{Description: description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -121,8 +101,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Client *spanner.Client - dialect string manifest tools.Manifest mcpManifest tools.McpManifest } @@ -161,10 +139,16 @@ func processRows(iter *spanner.RowIterator) ([]any, error) { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - // Check dialect here at RUNTIME instead of startup - if strings.ToLower(t.dialect) != "googlesql" { - return nil, fmt.Errorf("operation not supported: The 'spanner-list-graphs' tool is only available for GoogleSQL dialect databases. Your current database dialect is '%s'", t.dialect) + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err } + + // Check dialect here at RUNTIME instead of startup + if strings.ToLower(source.DatabaseDialect()) != "googlesql" { + return nil, fmt.Errorf("operation not supported: The 'spanner-list-graphs' tool is only available for GoogleSQL dialect databases. Your current database dialect is '%s'", source.DatabaseDialect()) + } + paramsMap := params.AsMap() graphNames, _ := paramsMap["graph_names"].(string) @@ -184,7 +168,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Execute the query (read-only) - iter := t.Client.Single().Query(ctx, stmt) + iter := source.SpannerClient().Single().Query(ctx, stmt) results, err := processRows(iter) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) @@ -209,16 +193,16 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } // GoogleSQL statement for listing graphs diff --git a/internal/tools/spanner/spannerlisttables/spannerlisttables.go b/internal/tools/spanner/spannerlisttables/spannerlisttables.go index b5d361ea12..bd41479fed 100644 --- a/internal/tools/spanner/spannerlisttables/spannerlisttables.go +++ b/internal/tools/spanner/spannerlisttables/spannerlisttables.go @@ -23,7 +23,6 @@ import ( "cloud.google.com/go/spanner" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - spannerdb "github.com/googleapis/genai-toolbox/internal/sources/spanner" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/iterator" @@ -50,11 +49,6 @@ type compatibleSource interface { DatabaseDialect() string } -// validate compatible sources are still compatible -var _ compatibleSource = &spannerdb.Source{} - -var compatibleSources = [...]string{spannerdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Define parameters for the tool allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault( @@ -107,8 +89,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Client: s.SpannerClient(), - dialect: s.DatabaseDialect(), manifest: tools.Manifest{Description: description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -121,8 +101,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Client *spanner.Client - dialect string manifest tools.Manifest mcpManifest tools.McpManifest } @@ -160,8 +138,8 @@ func processRows(iter *spanner.RowIterator) ([]any, error) { return out, nil } -func (t Tool) getStatement() string { - switch strings.ToLower(t.dialect) { +func (t Tool) getStatement(source compatibleSource) string { + switch strings.ToLower(source.DatabaseDialect()) { case "postgresql": return postgresqlStatement case "googlesql": @@ -173,10 +151,15 @@ func (t Tool) getStatement() string { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() // Get the appropriate SQL statement based on dialect - statement := t.getStatement() + statement := t.getStatement(source) // Prepare parameters based on dialect var stmtParams map[string]interface{} @@ -187,7 +170,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para outputFormat = "detailed" } - switch strings.ToLower(t.dialect) { + switch strings.ToLower(source.DatabaseDialect()) { case "postgresql": // PostgreSQL uses positional parameters ($1, $2) stmtParams = map[string]interface{}{ @@ -202,7 +185,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para "output_format": outputFormat, } default: - return nil, fmt.Errorf("unsupported dialect: %s", t.dialect) + return nil, fmt.Errorf("unsupported dialect: %s", source.DatabaseDialect()) } stmt := spanner.Statement{ @@ -211,7 +194,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Execute the query (read-only) - iter := t.Client.Single().Query(ctx, stmt) + iter := source.SpannerClient().Single().Query(ctx, stmt) results, err := processRows(iter) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) @@ -236,16 +219,16 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } // PostgreSQL statement for listing tables diff --git a/internal/tools/spanner/spannersql/spannersql.go b/internal/tools/spanner/spannersql/spannersql.go index 42cdd6559c..d1b7c1ab54 100644 --- a/internal/tools/spanner/spannersql/spannersql.go +++ b/internal/tools/spanner/spannersql/spannersql.go @@ -22,7 +22,6 @@ import ( "cloud.google.com/go/spanner" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - spannerdb "github.com/googleapis/genai-toolbox/internal/sources/spanner" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/iterator" @@ -49,11 +48,6 @@ type compatibleSource interface { DatabaseDialect() string } -// validate compatible sources are still compatible -var _ compatibleSource = &spannerdb.Source{} - -var compatibleSources = [...]string{spannerdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -74,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -97,8 +79,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Client: s.SpannerClient(), - dialect: s.DatabaseDialect(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -111,8 +91,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Client *spanner.Client - dialect string manifest tools.Manifest mcpManifest tools.McpManifest } @@ -153,6 +131,11 @@ func processRows(iter *spanner.RowIterator) ([]any, error) { } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -187,7 +170,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para newParams[i] = parameters.ParamValue{Name: name, Value: value} } - mapParams, err := getMapParams(newParams, t.dialect) + mapParams, err := getMapParams(newParams, source.DatabaseDialect()) if err != nil { return nil, fmt.Errorf("fail to get map params: %w", err) } @@ -200,10 +183,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } if t.ReadOnly { - iter := t.Client.Single().Query(ctx, stmt) + iter := source.SpannerClient().Single().Query(ctx, stmt) results, opErr = processRows(iter) } else { - _, opErr = t.Client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + _, opErr = source.SpannerClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { iter := txn.Query(ctx, stmt) results, err = processRows(iter) if err != nil { @@ -236,14 +219,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go index 848ae87125..e2c03a224a 100644 --- a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go +++ b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/sqlite" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/orderedmap" @@ -49,11 +48,6 @@ type compatibleSource interface { SQLiteDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &sqlite.Source{} - -var compatibleSources = [...]string{sqlite.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) @@ -90,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - DB: s.SQLiteDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -102,14 +83,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - DB *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + sql, ok := params.AsMap()["sql"].(string) if !ok { return nil, fmt.Errorf("missing or invalid 'sql' parameter") @@ -125,7 +109,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - results, err := t.DB.QueryContext(ctx, sql) + results, err := source.SQLiteDB().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -201,14 +185,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql_test.go b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql_test.go index acce6527e4..63079a883e 100644 --- a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql_test.go +++ b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql_test.go @@ -15,19 +15,13 @@ package sqliteexecutesql_test import ( - "context" - "database/sql" - "reflect" "testing" yaml "github.com/goccy/go-yaml" "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/server" "github.com/googleapis/genai-toolbox/internal/testutils" - "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql" - "github.com/googleapis/genai-toolbox/internal/util/orderedmap" - "github.com/googleapis/genai-toolbox/internal/util/parameters" _ "modernc.org/sqlite" ) @@ -81,251 +75,3 @@ func TestParseFromYamlExecuteSql(t *testing.T) { } } - -func setupTestDB(t *testing.T) *sql.DB { - db, err := sql.Open("sqlite", ":memory:") - if err != nil { - t.Fatalf("Failed to open in-memory database: %v", err) - } - return db -} - -func TestTool_Invoke(t *testing.T) { - ctx, err := testutils.ContextWithNewLogger() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - type fields struct { - Name string - Kind string - AuthRequired []string - Parameters parameters.Parameters - DB *sql.DB - } - type args struct { - ctx context.Context - params parameters.ParamValues - accessToken tools.AccessToken - } - tests := []struct { - name string - fields fields - args args - want any - wantErr bool - }{ - { - name: "create table", - fields: fields{ - DB: setupTestDB(t), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)"}, - }, - }, - want: nil, - wantErr: false, - }, - { - name: "insert data", - fields: fields{ - DB: setupTestDB(t), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER); INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25)"}, - }, - }, - want: nil, - wantErr: false, - }, - { - name: "select data", - fields: fields{ - DB: func() *sql.DB { - db := setupTestDB(t) - if _, err := db.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER); INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25)"); err != nil { - t.Fatalf("Failed to set up database for select: %v", err) - } - return db - }(), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "SELECT * FROM users"}, - }, - }, - want: []any{ - orderedmap.Row{ - Columns: []orderedmap.Column{ - {Name: "id", Value: int64(1)}, - {Name: "name", Value: "Alice"}, - {Name: "age", Value: int64(30)}, - }, - }, - orderedmap.Row{ - Columns: []orderedmap.Column{ - {Name: "id", Value: int64(2)}, - {Name: "name", Value: "Bob"}, - {Name: "age", Value: int64(25)}, - }, - }, - }, - wantErr: false, - }, - { - name: "drop table", - fields: fields{ - DB: func() *sql.DB { - db := setupTestDB(t) - if _, err := db.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)"); err != nil { - t.Fatalf("Failed to set up database for drop: %v", err) - } - return db - }(), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "DROP TABLE users"}, - }, - }, - want: nil, - wantErr: false, - }, - { - name: "invalid sql", - fields: fields{ - DB: setupTestDB(t), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "SELECT * FROM non_existent_table"}, - }, - }, - want: nil, - wantErr: true, - }, - { - name: "empty sql", - fields: fields{ - DB: setupTestDB(t), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: ""}, - }, - }, - want: nil, - wantErr: true, - }, - { - name: "data types", - fields: fields{ - DB: func() *sql.DB { - db := setupTestDB(t) - if _, err := db.Exec("CREATE TABLE data_types (id INTEGER PRIMARY KEY, null_col TEXT, blob_col BLOB)"); err != nil { - t.Fatalf("Failed to set up database for data types: %v", err) - } - if _, err := db.Exec("INSERT INTO data_types (id, null_col, blob_col) VALUES (1, NULL, ?)", []byte{1, 2, 3}); err != nil { - t.Fatalf("Failed to insert data for data types: %v", err) - } - return db - }(), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "SELECT * FROM data_types"}, - }, - }, - want: []any{ - orderedmap.Row{ - Columns: []orderedmap.Column{ - {Name: "id", Value: int64(1)}, - {Name: "null_col", Value: nil}, - {Name: "blob_col", Value: []byte{1, 2, 3}}, - }, - }, - }, - wantErr: false, - }, - { - name: "join operation", - fields: fields{ - DB: func() *sql.DB { - db := setupTestDB(t) - if _, err := db.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)"); err != nil { - t.Fatalf("Failed to set up database for join: %v", err) - } - if _, err := db.Exec("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25)"); err != nil { - t.Fatalf("Failed to insert data for join: %v", err) - } - if _, err := db.Exec("CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER, item TEXT)"); err != nil { - t.Fatalf("Failed to set up database for join: %v", err) - } - if _, err := db.Exec("INSERT INTO orders (id, user_id, item) VALUES (1, 1, 'Laptop'), (2, 2, 'Keyboard')"); err != nil { - t.Fatalf("Failed to insert data for join: %v", err) - } - return db - }(), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "SELECT u.name, o.item FROM users u JOIN orders o ON u.id = o.user_id"}, - }, - }, - want: []any{ - orderedmap.Row{ - Columns: []orderedmap.Column{ - {Name: "name", Value: "Alice"}, - {Name: "item", Value: "Laptop"}, - }, - }, - orderedmap.Row{ - Columns: []orderedmap.Column{ - {Name: "name", Value: "Bob"}, - {Name: "item", Value: "Keyboard"}, - }, - }, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tr := sqliteexecutesql.Tool{ - Config: sqliteexecutesql.Config{ - Name: tt.fields.Name, - Kind: tt.fields.Kind, - AuthRequired: tt.fields.AuthRequired, - }, - Parameters: tt.fields.Parameters, - DB: tt.fields.DB, - } - got, err := tr.Invoke(tt.args.ctx, nil, tt.args.params, tt.args.accessToken) - if (err != nil) != tt.wantErr { - t.Errorf("Tool.Invoke() error = %v, wantErr %v", err, tt.wantErr) - return - } - isEqual := false - if got != nil && len(got.([]any)) == 0 && len(tt.want.([]any)) == 0 { - isEqual = true // Special case for empty slices, since DeepEqual returns false - } else { - isEqual = reflect.DeepEqual(got, tt.want) - } - - if !isEqual { - t.Errorf("Tool.Invoke() = %+v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/tools/sqlite/sqlitesql/sqlitesql.go b/internal/tools/sqlite/sqlitesql/sqlitesql.go index 7a2f32ed40..e715252dc4 100644 --- a/internal/tools/sqlite/sqlitesql/sqlitesql.go +++ b/internal/tools/sqlite/sqlitesql/sqlitesql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/sqlite" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -47,11 +46,6 @@ type compatibleSource interface { SQLiteDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &sqlite.Source{} - -var compatibleSources = [...]string{sqlite.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Db: s.SQLiteDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -106,14 +87,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Db *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -126,7 +110,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Execute the SQL query with parameters - rows, err := t.Db.QueryContext(ctx, newStatement, newParams.AsSlice()...) + rows, err := source.SQLiteDB().QueryContext(ctx, newStatement, newParams.AsSlice()...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -200,14 +184,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/sqlite/sqlitesql/sqlitesql_test.go b/internal/tools/sqlite/sqlitesql/sqlitesql_test.go index d446e20496..eea6fddf4f 100644 --- a/internal/tools/sqlite/sqlitesql/sqlitesql_test.go +++ b/internal/tools/sqlite/sqlitesql/sqlitesql_test.go @@ -15,16 +15,12 @@ package sqlitesql_test import ( - "context" - "database/sql" - "reflect" "testing" yaml "github.com/goccy/go-yaml" "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/server" "github.com/googleapis/genai-toolbox/internal/testutils" - "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql" "github.com/googleapis/genai-toolbox/internal/util/parameters" _ "modernc.org/sqlite" @@ -179,148 +175,3 @@ func TestParseFromYamlWithTemplateSqlite(t *testing.T) { }) } } - -func setupTestDB(t *testing.T) *sql.DB { - db, err := sql.Open("sqlite", ":memory:") - if err != nil { - t.Fatalf("Failed to open in-memory database: %v", err) - } - - createTable := ` - CREATE TABLE users ( - id INTEGER PRIMARY KEY, - name TEXT, - age INTEGER - );` - if _, err := db.Exec(createTable); err != nil { - t.Fatalf("Failed to create table: %v", err) - } - - insertData := ` - INSERT INTO users (id, name, age) VALUES - (1, 'Alice', 30), - (2, 'Bob', 25);` - if _, err := db.Exec(insertData); err != nil { - t.Fatalf("Failed to insert data: %v", err) - } - - return db -} - -func TestTool_Invoke(t *testing.T) { - type fields struct { - Name string - Kind string - AuthRequired []string - Parameters parameters.Parameters - TemplateParameters parameters.Parameters - AllParams parameters.Parameters - Db *sql.DB - Statement string - } - type args struct { - ctx context.Context - params parameters.ParamValues - accessToken tools.AccessToken - } - tests := []struct { - name string - fields fields - args args - want any - wantErr bool - }{ - { - name: "simple select", - fields: fields{ - Db: setupTestDB(t), - Statement: "SELECT * FROM users", - }, - args: args{ - ctx: context.Background(), - }, - want: []any{ - map[string]any{"id": int64(1), "name": "Alice", "age": int64(30)}, - map[string]any{"id": int64(2), "name": "Bob", "age": int64(25)}, - }, - wantErr: false, - }, - { - name: "select with parameter", - fields: fields{ - Db: setupTestDB(t), - Statement: "SELECT * FROM users WHERE name = ?", - Parameters: []parameters.Parameter{ - parameters.NewStringParameter("name", "user name"), - }, - }, - args: args{ - ctx: context.Background(), - params: []parameters.ParamValue{ - {Name: "name", Value: "Alice"}, - }, - }, - want: []any{ - map[string]any{"id": int64(1), "name": "Alice", "age": int64(30)}, - }, - wantErr: false, - }, - { - name: "select with template parameter", - fields: fields{ - Db: setupTestDB(t), - Statement: "SELECT * FROM {{.tableName}}", - TemplateParameters: []parameters.Parameter{ - parameters.NewStringParameter("tableName", "table name"), - }, - }, - args: args{ - ctx: context.Background(), - params: []parameters.ParamValue{ - {Name: "tableName", Value: "users"}, - }, - }, - want: []any{ - map[string]any{"id": int64(1), "name": "Alice", "age": int64(30)}, - map[string]any{"id": int64(2), "name": "Bob", "age": int64(25)}, - }, - wantErr: false, - }, - { - name: "invalid sql", - fields: fields{ - Db: setupTestDB(t), - Statement: "SELECT * FROM non_existent_table", - }, - args: args{ - ctx: context.Background(), - }, - want: nil, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tr := sqlitesql.Tool{ - Config: sqlitesql.Config{ - Name: tt.fields.Name, - Kind: tt.fields.Kind, - AuthRequired: tt.fields.AuthRequired, - Statement: tt.fields.Statement, - Parameters: tt.fields.Parameters, - TemplateParameters: tt.fields.TemplateParameters, - }, - AllParams: tt.fields.AllParams, - Db: tt.fields.Db, - } - got, err := tr.Invoke(tt.args.ctx, nil, tt.args.params, tt.args.accessToken) - if (err != nil) != tt.wantErr { - t.Errorf("Tool.Invoke() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Tool.Invoke() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go b/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go index 45b53714df..b452de841d 100644 --- a/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go +++ b/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/tidb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -47,11 +46,6 @@ type compatibleSource interface { TiDBPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &tidb.Source{} - -var compatibleSources = [...]string{tidb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -68,18 +62,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -89,7 +71,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.TiDBPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -101,14 +82,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -122,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - results, err := t.Pool.QueryContext(ctx, sql) + results, err := source.TiDBPool().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -194,14 +178,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/tidb/tidbsql/tidbsql.go b/internal/tools/tidb/tidbsql/tidbsql.go index 01c6ef2cd0..f35d0a61db 100644 --- a/internal/tools/tidb/tidbsql/tidbsql.go +++ b/internal/tools/tidb/tidbsql/tidbsql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/tidb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -47,11 +46,6 @@ type compatibleSource interface { TiDBPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &tidb.Source{} - -var compatibleSources = [...]string{tidb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.TiDBPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -106,14 +87,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -126,7 +110,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.TiDBPool().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -206,14 +190,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/tools.go b/internal/tools/tools.go index bf99b3f33a..7283655f0c 100644 --- a/internal/tools/tools.go +++ b/internal/tools/tools.go @@ -90,9 +90,9 @@ type Tool interface { Manifest() Manifest McpManifest() McpManifest Authorized([]string) bool - RequiresClientAuthorization(SourceProvider) bool + RequiresClientAuthorization(SourceProvider) (bool, error) ToConfig() ToolConfig - GetAuthTokenHeaderName() string + GetAuthTokenHeaderName(SourceProvider) (string, error) } // SourceProvider defines the minimal view of the server.ResourceManager @@ -157,3 +157,16 @@ func IsAuthorized(authRequiredSources []string, verifiedAuthServices []string) b } return false } + +func GetCompatibleSource[T any](resourceMgr SourceProvider, sourceName, toolName, toolKind string) (T, error) { + var zero T + s, ok := resourceMgr.GetSource(sourceName) + if !ok { + return zero, fmt.Errorf("unable to retrieve source %q for tool %q", sourceName, toolName) + } + source, ok := s.(T) + if !ok { + return zero, fmt.Errorf("invalid source for %q tool: source %q is not a compatible type", toolKind, sourceName) + } + return source, nil +} diff --git a/internal/tools/trino/trinoexecutesql/trinoexecutesql.go b/internal/tools/trino/trinoexecutesql/trinoexecutesql.go index 771a880d81..f9f396bd03 100644 --- a/internal/tools/trino/trinoexecutesql/trinoexecutesql.go +++ b/internal/tools/trino/trinoexecutesql/trinoexecutesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/trino" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -46,11 +45,6 @@ type compatibleSource interface { TrinoDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &trino.Source{} - -var compatibleSources = [...]string{trino.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -67,18 +61,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The SQL query to execute against the Trino database.") params := parameters.Parameters{sqlParameter} @@ -88,7 +70,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Db: s.TrinoDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -100,21 +81,24 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Db *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + sliceParams := params.AsSlice() sql, ok := sliceParams[0].(string) if !ok { return nil, fmt.Errorf("unable to cast sql parameter: %v", sliceParams[0]) } - results, err := t.Db.QueryContext(ctx, sql) + results, err := source.TrinoDB().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -179,14 +163,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/trino/trinosql/trinosql.go b/internal/tools/trino/trinosql/trinosql.go index 9528b6dc33..7dd06d505c 100644 --- a/internal/tools/trino/trinosql/trinosql.go +++ b/internal/tools/trino/trinosql/trinosql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/trino" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -46,11 +45,6 @@ type compatibleSource interface { TrinoDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &trino.Source{} - -var compatibleSources = [...]string{trino.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, fmt.Errorf("unable to process parameters: %w", err) @@ -93,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Db: s.TrinoDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -105,14 +86,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Db *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -123,7 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - results, err := t.Db.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.TrinoDB().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -188,14 +172,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/utility/wait/wait.go b/internal/tools/utility/wait/wait.go index 8c49762b34..5b931ebcaf 100644 --- a/internal/tools/utility/wait/wait.go +++ b/internal/tools/utility/wait/wait.go @@ -114,14 +114,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/valkey/valkey.go b/internal/tools/valkey/valkey.go index 8b350f6375..8f9d90c264 100644 --- a/internal/tools/valkey/valkey.go +++ b/internal/tools/valkey/valkey.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - valkeysrc "github.com/googleapis/genai-toolbox/internal/sources/valkey" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/valkey-io/valkey-go" @@ -45,11 +44,6 @@ type compatibleSource interface { ValkeyClient() valkey.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &valkeysrc.Source{} - -var compatibleSources = [...]string{valkeysrc.SourceKind, valkeysrc.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -68,24 +62,11 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, cfg.Parameters, nil) // finish tool setup t := Tool{ Config: cfg, - Client: s.ValkeyClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -97,13 +78,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config - - Client valkey.Client manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + // Replace parameters commands, err := replaceCommandsParams(t.Commands, t.Parameters, params) if err != nil { @@ -114,7 +98,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para builtCmds := make(valkey.Commands, len(commands)) for i, cmd := range commands { - builtCmds[i] = t.Client.B().Arbitrary(cmd...).Build() + builtCmds[i] = source.ValkeyClient().B().Arbitrary(cmd...).Build() } if len(builtCmds) == 0 { @@ -122,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } // Execute commands - responses := t.Client.DoMulti(ctx, builtCmds...) + responses := source.ValkeyClient().DoMulti(ctx, builtCmds...) // Parse responses out := make([]any, len(t.Commands)) @@ -193,14 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/yugabytedbsql/yugabytedbsql.go b/internal/tools/yugabytedbsql/yugabytedbsql.go index 4564a62e05..3b774ac366 100644 --- a/internal/tools/yugabytedbsql/yugabytedbsql.go +++ b/internal/tools/yugabytedbsql/yugabytedbsql.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/yugabytedb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/yugabyte/pgx/v5/pgxpool" @@ -46,8 +45,6 @@ type compatibleSource interface { YugabyteDBPool() *pgxpool.Pool } -var compatibleSources = [...]string{yugabytedb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -67,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -90,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.YugabyteDBPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -102,14 +86,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *pgxpool.Pool + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -121,7 +108,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - results, err := t.Pool.Query(ctx, newStatement, sliceParams...) + results, err := source.YugabyteDBPool().Query(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -165,14 +152,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/util/util.go b/internal/util/util.go index 9b0f269ce7..657fe8bf29 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "io" + "net/http" "strings" "github.com/go-playground/validator/v10" @@ -119,6 +120,30 @@ func UserAgentFromContext(ctx context.Context) (string, error) { } } +type UserAgentRoundTripper struct { + userAgent string + next http.RoundTripper +} + +func NewUserAgentRoundTripper(ua string, next http.RoundTripper) *UserAgentRoundTripper { + return &UserAgentRoundTripper{ + userAgent: ua, + next: next, + } +} + +func (rt *UserAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // create a deep copy of the request + newReq := req.Clone(req.Context()) + ua := newReq.Header.Get("User-Agent") + if ua == "" { + newReq.Header.Set("User-Agent", rt.userAgent) + } else { + newReq.Header.Set("User-Agent", ua+" "+rt.userAgent) + } + return rt.next.RoundTrip(newReq) +} + func NewStrictDecoder(v interface{}) (*yaml.Decoder, error) { b, err := yaml.Marshal(v) if err != nil { diff --git a/server.json b/server.json index 9ba5d9657d..fe2dfd9a82 100644 --- a/server.json +++ b/server.json @@ -14,11 +14,11 @@ "url": "https://github.com/googleapis/genai-toolbox", "source": "github" }, - "version": "0.23.0", + "version": "0.24.0", "packages": [ { "registryType": "oci", - "identifier": "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:0.23.0", + "identifier": "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:0.24.0", "transport": { "type": "streamable-http", "url": "http://{host}:{port}/mcp" diff --git a/tests/clickhouse/clickhouse_integration_test.go b/tests/clickhouse/clickhouse_integration_test.go index 5391590edc..058e4d1b1a 100644 --- a/tests/clickhouse/clickhouse_integration_test.go +++ b/tests/clickhouse/clickhouse_integration_test.go @@ -15,9 +15,12 @@ package clickhouse import ( + "bytes" "context" "database/sql" + "encoding/json" "fmt" + "net/http" "os" "regexp" "strings" @@ -26,16 +29,9 @@ import ( _ "github.com/ClickHouse/clickhouse-go/v2" "github.com/google/uuid" - "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/clickhouse" "github.com/googleapis/genai-toolbox/internal/testutils" - clickhouseexecutesql "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql" - clickhouselistdatabases "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases" - clickhouselisttables "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables" - clickhousesql "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/googleapis/genai-toolbox/tests" - "go.opentelemetry.io/otel/trace/noop" ) var ( @@ -384,150 +380,125 @@ func TestClickHouseSQLTool(t *testing.T) { t.Fatalf("Failed to insert test data: %v", err) } - t.Run("SimpleSelect", func(t *testing.T) { - toolConfig := clickhousesql.Config{ - Name: "test-select", - Kind: "clickhouse-sql", - Source: "test-clickhouse", - Description: "Test select query", - Statement: fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName), - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - result, err := tool.Invoke(ctx, nil, parameters.ParamValues{}, "") - if err != nil { - t.Fatalf("Failed to invoke tool: %v", err) - } - - resultSlice, ok := result.([]any) - if !ok { - t.Fatalf("Expected result to be []any, got %T", result) - } - - if len(resultSlice) != 3 { - t.Errorf("Expected 3 results, got %d", len(resultSlice)) - } - }) - - t.Run("ParameterizedQuery", func(t *testing.T) { - toolConfig := clickhousesql.Config{ - Name: "test-param-query", - Kind: "clickhouse-sql", - Source: "test-clickhouse", - Description: "Test parameterized query", - Statement: fmt.Sprintf("SELECT * FROM %s WHERE age > ? ORDER BY id", tableName), - Parameters: parameters.Parameters{ - parameters.NewIntParameter("min_age", "Minimum age"), + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": getClickHouseVars(t), + }, + "tools": map[string]any{ + "test-select": map[string]any{ + "kind": ClickHouseToolKind, + "source": "my-instance", + "description": "Test select query", + "statement": fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName), }, - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - params := parameters.ParamValues{ - {Name: "min_age", Value: 28}, - } - - result, err := tool.Invoke(ctx, nil, params, "") - if err != nil { - t.Fatalf("Failed to invoke tool: %v", err) - } - - resultSlice, ok := result.([]any) - if !ok { - t.Fatalf("Expected result to be []any, got %T", result) - } - - if len(resultSlice) != 2 { - t.Errorf("Expected 2 results (Bob and Charlie), got %d", len(resultSlice)) - } - }) - - t.Run("EmptyResult", func(t *testing.T) { - toolConfig := clickhousesql.Config{ - Name: "test-empty-result", - Kind: "clickhouse-sql", - Source: "test-clickhouse", - Description: "Test query with no results", - Statement: fmt.Sprintf("SELECT * FROM %s WHERE id = ?", tableName), - Parameters: parameters.Parameters{ - parameters.NewIntParameter("id", "Record ID"), + "test-param-query": map[string]any{ + "kind": ClickHouseToolKind, + "source": "my-instance", + "description": "Test parameterized query", + "statement": fmt.Sprintf("SELECT * FROM %s WHERE age > ? ORDER BY id", tableName), + "parameters": []parameters.Parameter{ + parameters.NewIntParameter("min_age", "Minimum age"), + }, }, - } + "test-empty-result": map[string]any{ + "kind": ClickHouseToolKind, + "source": "my-instance", + "description": "Test query with no results", + "statement": fmt.Sprintf("SELECT * FROM %s WHERE id = ?", tableName), + "parameters": []parameters.Parameter{ + parameters.NewIntParameter("id", "Record ID"), + }, + }, + "test-invalid-sql": map[string]any{ + "kind": ClickHouseToolKind, + "source": "my-instance", + "description": "Test invalid SQL", + "statement": "SELEC * FROM nonexistent_table", // Typo in SELECT + }, + }, + } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } + var args []string + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } - params := parameters.ParamValues{ - {Name: "id", Value: 999}, // Non-existent ID - } - - result, err := tool.Invoke(ctx, nil, params, "") - if err != nil { - t.Fatalf("Failed to invoke tool: %v", err) - } - - // ClickHouse returns empty slice for no results, not nil - if resultSlice, ok := result.([]any); ok { - if len(resultSlice) != 0 { - t.Errorf("Expected empty result for non-existent record, got %d results", len(resultSlice)) + tcs := []struct { + name string + toolName string + requestBody []byte + resultSliceLen int + isErr bool + }{ + { + name: "SimpleSelect", + toolName: "test-select", + requestBody: []byte(`{}`), + resultSliceLen: 3, + }, + { + name: "ParameterizedQuery", + toolName: "test-param-query", + requestBody: []byte(`{"min_age": 28}`), + resultSliceLen: 2, + }, + { + name: "EmptyResult", + toolName: "test-empty-result", + requestBody: []byte(`{"id": 999}`), // non-existent id + resultSliceLen: 0, + }, + { + name: "InvalidSQL", + toolName: "test-invalid-sql", + requestBody: []byte(``), + isErr: true, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName) + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer(tc.requestBody), nil) + if resp.StatusCode != http.StatusOK { + if tc.isErr { + return + } + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - } else if result != nil { - t.Errorf("Expected empty slice or nil result for empty query, got %v", result) - } - }) - t.Run("InvalidSQL", func(t *testing.T) { - toolConfig := clickhousesql.Config{ - Name: "test-invalid-sql", - Kind: "clickhouse-sql", - Source: "test-clickhouse", - Description: "Test invalid SQL", - Statement: "SELEC * FROM nonexistent_table", // Typo in SELECT - } + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) + if err != nil { + t.Fatalf("error parsing response body") + } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + t.Logf("result is %s", got) - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } + var res []any + err = json.Unmarshal([]byte(got), &res) + if err != nil { + t.Fatalf("error parsing result") + } - _, err = tool.Invoke(ctx, nil, parameters.ParamValues{}, "") - if err == nil { - t.Error("Expected error for invalid SQL, got nil") - } - - if !strings.Contains(err.Error(), "Syntax error") && !strings.Contains(err.Error(), "SELEC") { - t.Errorf("Expected syntax error message, got: %v", err) - } - }) + if len(res) != tc.resultSliceLen { + t.Errorf("Expected %d results, got %d", tc.resultSliceLen, len(res)) + } + }) + } t.Logf("✅ clickhouse-sql tool tests completed successfully") } @@ -545,224 +516,108 @@ func TestClickHouseExecuteSQLTool(t *testing.T) { tableName := "test_exec_sql_" + strings.ReplaceAll(uuid.New().String(), "-", "") - t.Run("CreateTable", func(t *testing.T) { - toolConfig := clickhouseexecutesql.Config{ - Name: "test-create-table", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test create table", - } + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": getClickHouseVars(t), + }, + "tools": map[string]any{ + "execute-sql-tool": map[string]any{ + "kind": "clickhouse-execute-sql", + "source": "my-instance", + "description": "Test create table", + }, + }, + } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } + var args []string + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + tcs := []struct { + name string + sql string + resultSliceLen int + isErr bool + }{ + { + name: "CreateTable", + sql: fmt.Sprintf(`CREATE TABLE %s (id UInt32, data String) ENGINE = Memory`, tableName), + resultSliceLen: 0, + }, + { + name: "InsertData", + sql: fmt.Sprintf("INSERT INTO %s (id, data) VALUES (1, 'test1'), (2, 'test2')", tableName), + resultSliceLen: 0, + }, + { + name: "SelectData", + sql: fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName), + resultSliceLen: 2, + }, + { + name: "DropTable", + sql: fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName), + resultSliceLen: 0, + }, + { + name: "MissingSQL", + sql: "", + isErr: true, + }, - createSQL := fmt.Sprintf(` - CREATE TABLE %s ( - id UInt32, - data String - ) ENGINE = Memory - `, tableName) - - params := parameters.ParamValues{ - {Name: "sql", Value: createSQL}, - } - - result, err := tool.Invoke(ctx, nil, params, "") - if err != nil { - t.Fatalf("Failed to create table: %v", err) - } - - // CREATE TABLE should return nil or empty slice (no rows) - if resultSlice, ok := result.([]any); ok { - if len(resultSlice) != 0 { - t.Errorf("Expected empty result for CREATE TABLE, got %d results", len(resultSlice)) + { + name: "SQLInjectionAttempt", + sql: "SELECT 1; DROP TABLE system.users; SELECT 2", + isErr: true, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + param := fmt.Sprintf(`{"sql": "%s"}`, tc.sql) + api := "http://127.0.0.1:5000/api/tool/execute-sql-tool/invoke" + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(param)), nil) + if resp.StatusCode != http.StatusOK { + if tc.isErr { + return + } + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - } else if result != nil { - t.Errorf("Expected nil or empty slice for CREATE TABLE, got %v", result) - } - }) - - t.Run("InsertData", func(t *testing.T) { - toolConfig := clickhouseexecutesql.Config{ - Name: "test-insert", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test insert data", - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - insertSQL := fmt.Sprintf("INSERT INTO %s (id, data) VALUES (1, 'test1'), (2, 'test2')", tableName) - params := parameters.ParamValues{ - {Name: "sql", Value: insertSQL}, - } - - result, err := tool.Invoke(ctx, nil, params, "") - if err != nil { - t.Fatalf("Failed to insert data: %v", err) - } - - // INSERT should return nil or empty slice - if resultSlice, ok := result.([]any); ok { - if len(resultSlice) != 0 { - t.Errorf("Expected empty result for INSERT, got %d results", len(resultSlice)) + if tc.isErr { + t.Fatalf("expecting an error from server") } - } else if result != nil { - t.Errorf("Expected nil or empty slice for INSERT, got %v", result) - } - }) - t.Run("SelectData", func(t *testing.T) { - toolConfig := clickhouseexecutesql.Config{ - Name: "test-select", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test select data", - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - selectSQL := fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName) - params := parameters.ParamValues{ - {Name: "sql", Value: selectSQL}, - } - - result, err := tool.Invoke(ctx, nil, params, "") - if err != nil { - t.Fatalf("Failed to select data: %v", err) - } - - resultSlice, ok := result.([]any) - if !ok { - t.Fatalf("Expected result to be []any, got %T", result) - } - - if len(resultSlice) != 2 { - t.Errorf("Expected 2 results, got %d", len(resultSlice)) - } - }) - - t.Run("DropTable", func(t *testing.T) { - toolConfig := clickhouseexecutesql.Config{ - Name: "test-drop-table", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test drop table", - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName) - params := parameters.ParamValues{ - {Name: "sql", Value: dropSQL}, - } - - result, err := tool.Invoke(ctx, nil, params, "") - if err != nil { - t.Fatalf("Failed to drop table: %v", err) - } - - // DROP TABLE should return nil or empty slice - if resultSlice, ok := result.([]any); ok { - if len(resultSlice) != 0 { - t.Errorf("Expected empty result for DROP TABLE, got %d results", len(resultSlice)) + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) + if err != nil { + t.Fatalf("error parsing response body") } - } else if result != nil { - t.Errorf("Expected nil or empty slice for DROP TABLE, got %v", result) - } - }) - t.Run("MissingSQL", func(t *testing.T) { - toolConfig := clickhouseexecutesql.Config{ - Name: "test-missing-sql", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test missing SQL parameter", - } + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } + var res []any + err = json.Unmarshal([]byte(got), &res) + if err != nil { + t.Fatalf("error parsing result") + } - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - // Pass empty SQL parameter - this should cause an error - params := parameters.ParamValues{ - {Name: "sql", Value: ""}, - } - - _, err = tool.Invoke(ctx, nil, params, "") - if err == nil { - t.Error("Expected error for empty SQL parameter, got nil") - } else { - t.Logf("Got expected error for empty SQL parameter: %v", err) - } - }) - - t.Run("SQLInjectionAttempt", func(t *testing.T) { - toolConfig := clickhouseexecutesql.Config{ - Name: "test-sql-injection", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test SQL injection attempt", - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - // Try to execute multiple statements (should fail or execute safely) - injectionSQL := "SELECT 1; DROP TABLE system.users; SELECT 2" - params := parameters.ParamValues{ - {Name: "sql", Value: injectionSQL}, - } - - _, err = tool.Invoke(ctx, nil, params, "") - // This should either fail or only execute the first statement - // dont check the specific error as behavior may vary - _ = err // We're not checking the error intentionally - }) + if len(res) != tc.resultSliceLen { + t.Errorf("Expected %d results, got %d", tc.resultSliceLen, len(res)) + } + }) + } t.Logf("✅ clickhouse-execute-sql tool tests completed successfully") } @@ -778,6 +633,49 @@ func TestClickHouseEdgeCases(t *testing.T) { } defer pool.Close() + tableName := "test_nulls_" + strings.ReplaceAll(uuid.New().String(), "-", "") + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": getClickHouseVars(t), + }, + "tools": map[string]any{ + "execute-sql-tool": map[string]any{ + "kind": "clickhouse-execute-sql", + "source": "my-instance", + "description": "Test create table", + }, + "test-null-values": map[string]any{ + "kind": "clickhouse-sql", + "source": "my-instance", + "description": "Test null values", + "statement": fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName), + }, + "test-concurrent": map[string]any{ + "kind": "clickhouse-sql", + "source": "my-instance", + "description": "Test concurrent queries", + "statement": "SELECT number FROM system.numbers LIMIT ?", + "parameters": []parameters.Parameter{ + parameters.NewIntParameter("limit", "Limit"), + }, + }, + }, + } + + var args []string + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } t.Run("VeryLongQuery", func(t *testing.T) { // Create a very long but valid query var conditions []string @@ -786,42 +684,37 @@ func TestClickHouseEdgeCases(t *testing.T) { } longQuery := "SELECT 1 WHERE " + strings.Join(conditions, " AND ") - toolConfig := clickhouseexecutesql.Config{ - Name: "test-long-query", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test very long query", + api := "http://127.0.0.1:5000/api/tool/execute-sql-tool/invoke" + param := fmt.Sprintf(`{"sql": "%s"}`, longQuery) + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(param)), nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) + t.Fatalf("error parsing response body") } - params := parameters.ParamValues{ - {Name: "sql", Value: longQuery}, + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") } - result, err := tool.Invoke(ctx, nil, params, "") + var res []any + err = json.Unmarshal([]byte(got), &res) if err != nil { - t.Fatalf("Failed to execute long query: %v", err) + t.Fatalf("error parsing result") } // Should return [{1:1}] - if resultSlice, ok := result.([]any); ok { - if len(resultSlice) != 1 { - t.Errorf("Expected 1 result from long query, got %d", len(resultSlice)) - } + if len(res) != 1 { + t.Errorf("Expected 1 result from long query, got %d", len(res)) } }) t.Run("NullValues", func(t *testing.T) { - tableName := "test_nulls_" + strings.ReplaceAll(uuid.New().String(), "-", "") createSQL := fmt.Sprintf(` CREATE TABLE %s ( id UInt32, @@ -844,40 +737,35 @@ func TestClickHouseEdgeCases(t *testing.T) { t.Fatalf("Failed to insert null value: %v", err) } - toolConfig := clickhousesql.Config{ - Name: "test-null-values", - Kind: "clickhouse-sql", - Source: "test-clickhouse", - Description: "Test null values", - Statement: fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName), + api := "http://127.0.0.1:5000/api/tool/test-null-values/invoke" + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)), nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) + t.Fatalf("error parsing response body") } - result, err := tool.Invoke(ctx, nil, parameters.ParamValues{}, "") - if err != nil { - t.Fatalf("Failed to select null values: %v", err) - } - - resultSlice, ok := result.([]any) + got, ok := body["result"].(string) if !ok { - t.Fatalf("Expected result to be []any, got %T", result) + t.Fatalf("unable to find result in response body") } - if len(resultSlice) != 2 { - t.Errorf("Expected 2 results, got %d", len(resultSlice)) + var res []any + err = json.Unmarshal([]byte(got), &res) + if err != nil { + t.Fatalf("error parsing result") + } + + if len(res) != 2 { + t.Errorf("Expected 2 result from long query, got %d", len(res)) } // Check that null is properly handled - if firstRow, ok := resultSlice[0].(map[string]any); ok { + if firstRow, ok := res[0].(map[string]any); ok { if _, hasNullableField := firstRow["nullable_field"]; !hasNullableField { t.Error("Expected nullable_field in result") } @@ -885,47 +773,38 @@ func TestClickHouseEdgeCases(t *testing.T) { }) t.Run("ConcurrentQueries", func(t *testing.T) { - toolConfig := clickhousesql.Config{ - Name: "test-concurrent", - Kind: "clickhouse-sql", - Source: "test-clickhouse", - Description: "Test concurrent queries", - Statement: "SELECT number FROM system.numbers LIMIT ?", - Parameters: parameters.Parameters{ - parameters.NewIntParameter("limit", "Limit"), - }, - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - // Run multiple queries concurrently done := make(chan bool, 5) for i := 0; i < 5; i++ { go func(n int) { defer func() { done <- true }() - params := parameters.ParamValues{ - {Name: "limit", Value: n + 1}, + params := fmt.Sprintf(`{"limit": %d}`, n+1) + api := "http://127.0.0.1:5000/api/tool/test-concurrent/invoke" + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(params)), nil) + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - result, err := tool.Invoke(ctx, nil, params, "") + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) if err != nil { - t.Errorf("Concurrent query %d failed: %v", n, err) - return + t.Errorf("error parsing response body") } - if resultSlice, ok := result.([]any); ok { - if len(resultSlice) != n+1 { - t.Errorf("Query %d: expected %d results, got %d", n, n+1, len(resultSlice)) - } + got, ok := body["result"].(string) + if !ok { + t.Errorf("unable to find result in response body") + } + + var res []any + err = json.Unmarshal([]byte(got), &res) + if err != nil { + t.Errorf("error parsing result") + } + + if len(res) != n+1 { + t.Errorf("Query %d: expected %d results, got %d", n, n+1, len(res)) } }(i) } @@ -939,25 +818,6 @@ func TestClickHouseEdgeCases(t *testing.T) { t.Logf("✅ Edge case tests completed successfully") } -func createMockSource(t *testing.T, pool *sql.DB) sources.Source { - config := clickhouse.Config{ - Host: ClickHouseHost, - Port: ClickHousePort, - Database: ClickHouseDatabase, - User: ClickHouseUser, - Password: ClickHousePass, - Protocol: ClickHouseProtocol, - Secure: false, - } - - source, err := config.Initialize(context.Background(), noop.NewTracerProvider().Tracer("")) - if err != nil { - t.Fatalf("Failed to initialize source: %v", err) - } - - return source -} - // getClickHouseSQLParamToolInfo returns statements and param for my-tool clickhouse-sql kind func getClickHouseSQLParamToolInfo(tableName string) (string, string, string, string, string, string, []any) { createStatement := fmt.Sprintf("CREATE TABLE %s (id UInt32, name String) ENGINE = Memory", tableName) @@ -1036,44 +896,70 @@ func TestClickHouseListDatabasesTool(t *testing.T) { _, _ = pool.ExecContext(ctx, fmt.Sprintf("DROP DATABASE IF EXISTS %s", testDBName)) }() + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": getClickHouseVars(t), + }, + "tools": map[string]any{ + "test-list-databases": map[string]any{ + "kind": "clickhouse-list-databases", + "source": "my-instance", + "description": "Test listing databases", + }, + "test-invalid-source": map[string]any{ + "kind": "clickhouse-list-databases", + "source": "non-existent-source", + "description": "Test with invalid source", + }, + }, + } + + var args []string + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + t.Run("ListDatabases", func(t *testing.T) { - toolConfig := clickhouselistdatabases.Config{ - Name: "test-list-databases", - Kind: "clickhouse-list-databases", - Source: "test-clickhouse", - Description: "Test listing databases", + api := "http://127.0.0.1:5000/api/tool/test-list-databases/invoke" + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)), nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) + t.Fatalf("error parsing response body") } - params := parameters.ParamValues{} - - result, err := tool.Invoke(ctx, nil, params, "") - if err != nil { - t.Fatalf("Failed to list databases: %v", err) - } - - databases, ok := result.([]map[string]any) + databases, ok := body["result"].(string) if !ok { - t.Fatalf("Expected result to be []map[string]any, got %T", result) + t.Fatalf("unable to find result in response body") + } + var res []map[string]any + err = json.Unmarshal([]byte(databases), &res) + if err != nil { + t.Errorf("error parsing result") } // Should contain at least the default database and our test database - system and default - if len(databases) < 2 { - t.Errorf("Expected at least 2 databases, got %d", len(databases)) + if len(res) < 2 { + t.Errorf("Expected at least 2 databases, got %d", len(res)) } found := false foundDefault := false - for _, db := range databases { + for _, db := range res { if name, ok := db["name"].(string); ok { if name == testDBName { found = true @@ -1095,21 +981,12 @@ func TestClickHouseListDatabasesTool(t *testing.T) { }) t.Run("ListDatabasesWithInvalidSource", func(t *testing.T) { - toolConfig := clickhouselistdatabases.Config{ - Name: "test-invalid-source", - Kind: "clickhouse-list-databases", - Source: "non-existent-source", - Description: "Test with invalid source", + api := "http://127.0.0.1:5000/api/tool/test-invalid-source/invoke" + resp, _ := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)), nil) + if resp.StatusCode == http.StatusOK { + t.Fatalf("expected error for non-existent source, but got 200 OK") } - sourcesMap := map[string]sources.Source{} - - _, err := toolConfig.Initialize(sourcesMap) - if err == nil { - t.Error("Expected error for non-existent source, got nil") - } else { - t.Logf("Got expected error for invalid source: %v", err) - } }) t.Logf("✅ clickhouse-list-databases tool tests completed successfully") @@ -1148,46 +1025,71 @@ func TestClickHouseListTablesTool(t *testing.T) { t.Fatalf("Failed to create test table 2: %v", err) } + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": getClickHouseVars(t), + }, + "tools": map[string]any{ + "test-list-tables": map[string]any{ + "kind": "clickhouse-list-tables", + "source": "my-instance", + "description": "Test listing tables", + }, + "test-invalid-source": map[string]any{ + "kind": "clickhouse-list-tables", + "source": "non-existent-source", + "description": "Test with invalid source", + }, + }, + } + + var args []string + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + t.Run("ListTables", func(t *testing.T) { - toolConfig := clickhouselisttables.Config{ - Name: "test-list-tables", - Kind: "clickhouse-list-tables", - Source: "test-clickhouse", - Description: "Test listing tables", + api := "http://127.0.0.1:5000/api/tool/test-list-tables/invoke" + params := fmt.Sprintf(`{"database": "%s"}`, testDBName) + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(params)), nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) + t.Fatalf("error parsing response body") } - params := parameters.ParamValues{ - {Name: "database", Value: testDBName}, - } - - result, err := tool.Invoke(ctx, nil, params, "") - if err != nil { - t.Fatalf("Failed to list tables: %v", err) - } - - tables, ok := result.([]map[string]any) + tables, ok := body["result"].(string) if !ok { - t.Fatalf("Expected result to be []map[string]any, got %T", result) + t.Fatalf("Expected result to be []map[string]any, got %T", tables) + } + var res []map[string]any + err = json.Unmarshal([]byte(tables), &res) + if err != nil { + t.Errorf("error parsing result") } // Should contain exactly 2 tables that we created - if len(tables) != 2 { - t.Errorf("Expected 2 tables, got %d", len(tables)) + if len(res) != 2 { + t.Errorf("Expected 2 tables, got %d", len(res)) } foundTable1 := false foundTable2 := false - for _, table := range tables { + for _, table := range res { if name, ok := table["name"].(string); ok { if name == testTable1 { foundTable1 = true @@ -1215,48 +1117,18 @@ func TestClickHouseListTablesTool(t *testing.T) { }) t.Run("ListTablesWithMissingDatabase", func(t *testing.T) { - toolConfig := clickhouselisttables.Config{ - Name: "test-list-tables-missing-db", - Kind: "clickhouse-list-tables", - Source: "test-clickhouse", - Description: "Test listing tables without database parameter", - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - params := parameters.ParamValues{} - - _, err = tool.Invoke(ctx, nil, params, "") - if err == nil { - t.Error("Expected error for missing database parameter, got nil") - } else { - t.Logf("Got expected error for missing database: %v", err) + api := "http://127.0.0.1:5000/api/tool/test-list-tables/invoke" + resp, _ := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)), nil) + if resp.StatusCode == http.StatusOK { + t.Error("Expected error for missing database parameter, but got 200 OK") } }) t.Run("ListTablesWithInvalidSource", func(t *testing.T) { - toolConfig := clickhouselisttables.Config{ - Name: "test-invalid-source", - Kind: "clickhouse-list-tables", - Source: "non-existent-source", - Description: "Test with invalid source", - } - - sourcesMap := map[string]sources.Source{} - - _, err := toolConfig.Initialize(sourcesMap) - if err == nil { - t.Error("Expected error for non-existent source, got nil") - } else { - t.Logf("Got expected error for invalid source: %v", err) + api := "http://127.0.0.1:5000/api/tool/test-invalid-source/invoke" + resp, _ := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)), nil) + if resp.StatusCode == http.StatusOK { + t.Error("Expected error for non-existent source, but got 200 OK") } }) diff --git a/tests/cloudgda/cloud_gda_integration_test.go b/tests/cloudgda/cloud_gda_integration_test.go new file mode 100644 index 0000000000..3a7c8ad07f --- /dev/null +++ b/tests/cloudgda/cloud_gda_integration_test.go @@ -0,0 +1,233 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudgda_test + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "strings" + "testing" + "time" + + "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" + "github.com/googleapis/genai-toolbox/tests" +) + +var ( + cloudGdaToolKind = "cloud-gemini-data-analytics-query" +) + +type cloudGdaTransport struct { + transport http.RoundTripper + url *url.URL +} + +func (t *cloudGdaTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if strings.HasPrefix(req.URL.String(), "https://geminidataanalytics.googleapis.com") { + req.URL.Scheme = t.url.Scheme + req.URL.Host = t.url.Host + } + return t.transport.RoundTrip(req) +} + +type masterHandler struct { + t *testing.T +} + +func (h *masterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.UserAgent(), "genai-toolbox/") { + h.t.Errorf("User-Agent header not found") + } + + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Verify URL structure + // Expected: /v1beta/projects/{project}/locations/global:queryData + if !strings.Contains(r.URL.Path, ":queryData") || !strings.Contains(r.URL.Path, "locations/global") { + h.t.Errorf("unexpected URL path: %s", r.URL.Path) + http.Error(w, "Not found", http.StatusNotFound) + return + } + + var reqBody cloudgda.QueryDataRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + h.t.Fatalf("failed to decode request body: %v", err) + } + + if reqBody.Prompt == "" { + http.Error(w, "missing prompt", http.StatusBadRequest) + return + } + + response := map[string]any{ + "queryResult": "SELECT * FROM table;", + "naturalLanguageAnswer": "Here is the answer.", + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func TestCloudGdaToolEndpoints(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + handler := &masterHandler{t: t} + server := httptest.NewServer(handler) + defer server.Close() + + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse server URL: %v", err) + } + + originalTransport := http.DefaultClient.Transport + if originalTransport == nil { + originalTransport = http.DefaultTransport + } + http.DefaultClient.Transport = &cloudGdaTransport{ + transport: originalTransport, + url: serverURL, + } + t.Cleanup(func() { + http.DefaultClient.Transport = originalTransport + }) + + var args []string + toolsFile := getCloudGdaToolsConfig() + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + toolName := "cloud-gda-query" + + // 1. RunToolGetTestByName + expectedManifest := map[string]any{ + toolName: map[string]any{ + "description": "Test GDA Tool", + "parameters": []any{ + map[string]any{ + "name": "prompt", + "type": "string", + "description": "The natural language question to ask.", + "required": true, + "authSources": []any{}, + }, + }, + "authRequired": []any{}, + }, + } + tests.RunToolGetTestByName(t, toolName, expectedManifest) + + // 2. RunToolInvokeParametersTest + params := []byte(`{"prompt": "test question"}`) + tests.RunToolInvokeParametersTest(t, toolName, params, "\"queryResult\":\"SELECT * FROM table;\"") + + // 3. Manual MCP Tool Call Test + // Initialize MCP session + sessionId := tests.RunInitialize(t, "2024-11-05") + + // Construct MCP Request + mcpReq := jsonrpc.JSONRPCRequest{ + Jsonrpc: "2.0", + Id: "test-mcp-call", + Request: jsonrpc.Request{ + Method: "tools/call", + }, + Params: map[string]any{ + "name": toolName, + "arguments": map[string]any{ + "prompt": "test question", + }, + }, + } + reqBytes, _ := json.Marshal(mcpReq) + + headers := map[string]string{} + if sessionId != "" { + headers["Mcp-Session-Id"] = sessionId + } + + // Send Request + resp, respBody := tests.RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/mcp", bytes.NewBuffer(reqBytes), headers) + + if resp.StatusCode != http.StatusOK { + t.Fatalf("MCP request failed with status %d: %s", resp.StatusCode, string(respBody)) + } + + // Check Response + respStr := string(respBody) + if !strings.Contains(respStr, "SELECT * FROM table;") { + t.Errorf("MCP response does not contain expected query result: %s", respStr) + } +} + +func getCloudGdaToolsConfig() map[string]any { + // Mocked responses and a dummy `projectId` are used in this integration + // test due to limited project-specific allowlisting. API functionality is + // verified via internal monitoring; this test specifically validates the + // integration flow between the source and the tool. + return map[string]any{ + "sources": map[string]any{ + "my-gda-source": map[string]any{ + "kind": "cloud-gemini-data-analytics", + "projectId": "test-project", + }, + }, + "tools": map[string]any{ + "cloud-gda-query": map[string]any{ + "kind": cloudGdaToolKind, + "source": "my-gda-source", + "description": "Test GDA Tool", + "location": "us-central1", + "context": map[string]any{ + "datasourceReferences": map[string]any{ + "spannerReference": map[string]any{ + "databaseReference": map[string]any{ + "projectId": "test-project", + "instanceId": "test-instance", + "databaseId": "test-db", + "engine": "GOOGLE_SQL", + }, + }, + }, + }, + }, + }, + } +} diff --git a/tests/cloudmonitoring/cloud_monitoring_integration_test.go b/tests/cloudmonitoring/cloud_monitoring_integration_test.go index 40bfa26234..f5833244a6 100644 --- a/tests/cloudmonitoring/cloud_monitoring_integration_test.go +++ b/tests/cloudmonitoring/cloud_monitoring_integration_test.go @@ -53,8 +53,6 @@ func TestTool_Invoke(t *testing.T) { Description: "Test Cloudmonitoring Tool", }, AllParams: parameters.Parameters{}, - BaseURL: server.URL, - Client: &http.Client{}, } // Define the test parameters @@ -99,8 +97,6 @@ func TestTool_Invoke_Error(t *testing.T) { Description: "Test Cloudmonitoring Tool", }, AllParams: parameters.Parameters{}, - BaseURL: server.URL, - Client: &http.Client{}, } // Define the test parameters diff --git a/tests/cloudsqlmysql/cloud_sql_mysql_integration_test.go b/tests/cloudsqlmysql/cloud_sql_mysql_integration_test.go index 192c779ea9..55b3035868 100644 --- a/tests/cloudsqlmysql/cloud_sql_mysql_integration_test.go +++ b/tests/cloudsqlmysql/cloud_sql_mysql_integration_test.go @@ -163,6 +163,7 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) { const expectedOwner = "'toolbox-identity'@'%'" tests.RunMySQLListTablesTest(t, CloudSQLMySQLDatabase, tableNameParam, tableNameAuth, expectedOwner) tests.RunMySQLListActiveQueriesTest(t, ctx, pool) + tests.RunMySQLGetQueryPlanTest(t, ctx, pool, CloudSQLMySQLDatabase, tableNameParam) } // Test connection with different IP type diff --git a/tests/common.go b/tests/common.go index e2887c5ed9..5ada5a6b32 100644 --- a/tests/common.go +++ b/tests/common.go @@ -448,6 +448,11 @@ func AddMySQLPrebuiltToolConfig(t *testing.T, config map[string]any) map[string] "source": "my-instance", "description": "Lists table fragmentation in the database.", } + tools["get_query_plan"] = map[string]any{ + "kind": "mysql-get-query-plan", + "source": "my-instance", + "description": "Gets the query plan for a SQL statement.", + } config["tools"] = tools return config } diff --git a/tests/mysql/mysql_integration_test.go b/tests/mysql/mysql_integration_test.go index 4cb81197be..113767fd1d 100644 --- a/tests/mysql/mysql_integration_test.go +++ b/tests/mysql/mysql_integration_test.go @@ -143,4 +143,5 @@ func TestMySQLToolEndpoints(t *testing.T) { tests.RunMySQLListActiveQueriesTest(t, ctx, pool) tests.RunMySQLListTablesMissingUniqueIndexes(t, ctx, pool, MySQLDatabase) tests.RunMySQLListTableFragmentationTest(t, MySQLDatabase, tableNameParam, tableNameAuth) + tests.RunMySQLGetQueryPlanTest(t, ctx, pool, MySQLDatabase, tableNameParam) } diff --git a/tests/oracle/oracle_integration_test.go b/tests/oracle/oracle_integration_test.go index 04f272a1b8..0021679e9e 100644 --- a/tests/oracle/oracle_integration_test.go +++ b/tests/oracle/oracle_integration_test.go @@ -43,6 +43,7 @@ func getOracleVars(t *testing.T) map[string]any { return map[string]any{ "kind": OracleSourceKind, "connectionString": OracleConnStr, + "useOCI": true, "user": OracleUser, "password": OraclePass, } @@ -50,9 +51,11 @@ func getOracleVars(t *testing.T) map[string]any { // Copied over from oracle.go func initOracleConnection(ctx context.Context, user, pass, connStr string) (*sql.DB, error) { - fullConnStr := fmt.Sprintf("oracle://%s:%s@%s", user, pass, connStr) + // Build the full Oracle connection string for godror driver + fullConnStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`, + user, pass, connStr) - db, err := sql.Open("oracle", fullConnStr) + db, err := sql.Open("godror", fullConnStr) if err != nil { return nil, fmt.Errorf("unable to open Oracle connection: %w", err) } @@ -116,13 +119,15 @@ func TestOracleSimpleToolEndpoints(t *testing.T) { // Get configs for tests select1Want := "[{\"1\":1}]" - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ORA-00900: invalid SQL statement\n error occur at position: 0"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: dpiStmt_execute: ORA-00900: invalid SQL statement"}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id NUMBER GENERATED AS IDENTITY PRIMARY KEY, name VARCHAR2(255))"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` // Run tests tests.RunToolGetTest(t) tests.RunToolInvokeTest(t, select1Want, + tests.DisableOptionalNullParamTest(), + tests.WithMyToolById4Want("[{\"id\":4,\"name\":\"\"}]"), tests.DisableArrayTest(), ) tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want) diff --git a/tests/spanner/spanner_integration_test.go b/tests/spanner/spanner_integration_test.go index 324738f6cb..4daf87a27e 100644 --- a/tests/spanner/spanner_integration_test.go +++ b/tests/spanner/spanner_integration_test.go @@ -277,7 +277,7 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database. // tear down test op, err = adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{ Database: dbString, - Statements: []string{fmt.Sprintf("DROP TABLE %s", tableName)}, + Statements: []string{fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)}, }) if err != nil { t.Errorf("unable to start drop %s operation: %s", tableName, err) @@ -310,7 +310,7 @@ func setupSpannerGraph(t *testing.T, ctx context.Context, adminClient *database. // tear down test op, err = adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{ Database: dbString, - Statements: []string{fmt.Sprintf("DROP PROPERTY GRAPH %s", graphName)}, + Statements: []string{fmt.Sprintf("DROP PROPERTY GRAPH IF EXISTS %s", graphName)}, }) if err != nil { t.Errorf("unable to start drop %s operation: %s", graphName, err) diff --git a/tests/tool.go b/tests/tool.go index 9fcd045d76..65a358ca5d 100644 --- a/tests/tool.go +++ b/tests/tool.go @@ -2401,10 +2401,10 @@ func RunPostgresListPgSettingsTest(t *testing.T, ctx context.Context, pool *pgxp // RunPostgresDatabaseStatsTest tests the database_stats tool by comparing API results // against a direct query to the database. func RunPostgresListDatabaseStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { - dbName1 := "test_db_stats_1" - dbOwner1 := "test_user1" - dbName2 := "test_db_stats_2" - dbOwner2 := "test_user2" + dbName1 := "test_db_stats_" + strings.ReplaceAll(uuid.NewString(), "-", "") + dbOwner1 := "test_user_" + strings.ReplaceAll(uuid.NewString(), "-", "") + dbName2 := "test_db_stats_" + strings.ReplaceAll(uuid.NewString(), "-", "") + dbOwner2 := "test_user_" + strings.ReplaceAll(uuid.NewString(), "-", "") cleanup1 := setUpDatabase(t, ctx, pool, dbName1, dbOwner1) defer cleanup1() @@ -3377,6 +3377,81 @@ func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNamePar } } +func RunMySQLGetQueryPlanTest(t *testing.T, ctx context.Context, pool *sql.DB, databaseName, tableNameParam string) { + // Create a simple query to explain + query := fmt.Sprintf("SELECT * FROM %s", tableNameParam) + + invokeTcs := []struct { + name string + requestBody io.Reader + wantStatusCode int + checkResult func(t *testing.T, result any) + }{ + { + name: "invoke get_query_plan with valid query", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"sql_statement": "%s"}`, query)), + wantStatusCode: http.StatusOK, + checkResult: func(t *testing.T, result any) { + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("result should be a map, got %T", result) + } + if _, ok := resultMap["query_block"]; !ok { + t.Errorf("result should contain 'query_block', got %v", resultMap) + } + }, + }, + { + name: "invoke get_query_plan with invalid query", + requestBody: bytes.NewBufferString(`{"sql_statement": "SELECT * FROM non_existent_table"}`), + wantStatusCode: http.StatusBadRequest, + checkResult: nil, + }, + } + + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + const api = "http://127.0.0.1:5000/api/tool/get_query_plan/invoke" + resp, respBytes := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBytes)) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + var bodyWrapper map[string]json.RawMessage + + if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil { + t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes)) + } + + resultJSON, ok := bodyWrapper["result"] + if !ok { + t.Fatal("unable to find 'result' in response body") + } + + var resultString string + if err := json.Unmarshal(resultJSON, &resultString); err != nil { + if string(resultJSON) == "null" { + resultString = "null" + } else { + t.Fatalf("'result' is not a JSON-encoded string: %s", err) + } + } + + var got map[string]any + if err := json.Unmarshal([]byte(resultString), &got); err != nil { + t.Fatalf("failed to unmarshal actual result string: %v", err) + } + + if tc.checkResult != nil { + tc.checkResult(t, got) + } + }) + } +} + // RunMSSQLListTablesTest run tests againsts the mssql-list-tables tools. func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) { // TableNameParam columns to construct want.