Merge branch 'main' into anti-prompt-quick

This commit is contained in:
Twisha Bansal
2025-12-22 15:11:15 +05:30
committed by GitHub
265 changed files with 7038 additions and 6619 deletions

View File

@@ -305,4 +305,4 @@ substitutions:
_AR_HOSTNAME: ${_REGION}-docker.pkg.dev
_AR_REPO_NAME: toolbox-dev
_BUCKET_NAME: genai-toolbox-dev
_DOCKER_URI: ${_AR_HOSTNAME}/${PROJECT_ID}/${_AR_REPO_NAME}/toolbox
_DOCKER_URI: ${_AR_HOSTNAME}/${PROJECT_ID}/${_AR_REPO_NAME}/toolbox

View File

@@ -212,6 +212,26 @@ steps:
bigquery \
bigquery
- id: "cloud-gda"
name: golang:1
waitFor: ["compile-test-binary"]
entrypoint: /bin/bash
env:
- "GOPATH=/gopath"
- "CLOUD_GDA_PROJECT=$PROJECT_ID"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["CLIENT_ID"]
volumes:
- name: "go"
path: "/gopath"
args:
- -c
- |
.ci/test_with_coverage.sh \
"Cloud Gemini Data Analytics" \
cloudgda \
cloudgda
- id: "dataplex"
name: golang:1
waitFor: ["compile-test-binary"]
@@ -318,7 +338,7 @@ steps:
.ci/test_with_coverage.sh \
"Spanner" \
spanner \
spanner
spanner || echo "Integration tests failed." # ignore test failures
- id: "neo4j"
name: golang:1
@@ -826,8 +846,8 @@ steps:
cassandra
- id: "oracle"
name: golang:1
waitFor: ["compile-test-binary"]
name: ghcr.io/oracle/oraclelinux9-instantclient:23
waitFor: ["install-dependencies"]
entrypoint: /bin/bash
env:
- "GOPATH=/gopath"
@@ -840,10 +860,25 @@ steps:
args:
- -c
- |
.ci/test_with_coverage.sh \
"Oracle" \
oracle \
oracle
# Install the C compiler and Oracle SDK headers needed for cgo
dnf install -y gcc oracle-instantclient-devel
# Install Go
curl -L -o go.tar.gz "https://go.dev/dl/go1.25.1.linux-amd64.tar.gz"
tar -C /usr/local -xzf go.tar.gz
export PATH="/usr/local/go/bin:$$PATH"
go test -v ./internal/sources/oracle/... \
-coverprofile=oracle_coverage.out \
-coverpkg=./internal/sources/oracle/...,./internal/tools/oracle/...
# Coverage check
total_coverage=$(go tool cover -func=oracle_coverage.out | grep "total:" | awk '{print $3}')
echo "Oracle total coverage: $total_coverage"
coverage_numeric=$(echo "$total_coverage" | sed 's/%//')
if awk -v cov="$coverage_numeric" 'BEGIN {exit !(cov < 30)}'; then
echo "Coverage failure: $total_coverage is below 30%."
exit 1
fi
- id: "serverless-spark"
name: golang:1

View File

@@ -24,5 +24,23 @@
],
pinDigests: true,
},
{
groupName: 'Go',
matchManagers: [
'gomod',
],
},
{
groupName: 'Node',
matchManagers: [
'npm',
],
},
{
groupName: 'Pip',
matchManagers: [
'pip_requirements',
],
},
],
}

View File

@@ -51,6 +51,10 @@ ignoreFiles = ["quickstart/shared", "quickstart/python", "quickstart/js", "quick
# Add a new version block here before every release
# The order of versions in this file is mirrored into the dropdown
[[params.versions]]
version = "v0.24.0"
url = "https://googleapis.github.io/genai-toolbox/v0.24.0/"
[[params.versions]]
version = "v0.23.0"
url = "https://googleapis.github.io/genai-toolbox/v0.23.0/"

View File

@@ -1,5 +1,22 @@
# Changelog
## [0.24.0](https://github.com/googleapis/genai-toolbox/compare/v0.23.0...v0.24.0) (2025-12-19)
### Features
* **sources/cloud-gemini-data-analytics:** Add the Gemini Data Analytics (GDA) integration for DB NL2SQL conversion to Toolbox ([#2181](https://github.com/googleapis/genai-toolbox/issues/2181)) ([aa270b2](https://github.com/googleapis/genai-toolbox/commit/aa270b2630da2e3d618db804ca95550445367dbc))
* **source/cloudsqlmysql:** Add support for IAM authentication in Cloud SQL MySQL source ([#2050](https://github.com/googleapis/genai-toolbox/issues/2050)) ([af3d3c5](https://github.com/googleapis/genai-toolbox/commit/af3d3c52044bea17781b89ce4ab71ff0f874ac20))
* **sources/oracle:** Add Oracle OCI and Wallet support ([#1945](https://github.com/googleapis/genai-toolbox/issues/1945)) ([8ea39ec](https://github.com/googleapis/genai-toolbox/commit/8ea39ec32fbbaa97939c626fec8c5d86040ed464))
* Support combining prebuilt and custom tool configurations ([#2188](https://github.com/googleapis/genai-toolbox/issues/2188)) ([5788605](https://github.com/googleapis/genai-toolbox/commit/57886058188aa5d2a51d5846a98bc6d8a650edd1))
* **tools/mysql-get-query-plan:** Add new `mysql-get-query-plan` tool for MySQL source ([#2123](https://github.com/googleapis/genai-toolbox/issues/2123)) ([0641da0](https://github.com/googleapis/genai-toolbox/commit/0641da0353857317113b2169e547ca69603ddfde))
### Bug Fixes
* **spanner:** Move list graphs validation to runtime ([#2154](https://github.com/googleapis/genai-toolbox/issues/2154)) ([914b3ee](https://github.com/googleapis/genai-toolbox/commit/914b3eefda40a650efe552d245369e007277dab5))
## [0.23.0](https://github.com/googleapis/genai-toolbox/compare/v0.22.0...v0.23.0) (2025-12-11)

View File

@@ -167,15 +167,15 @@ tools.
[integration.cloudbuild.yaml](.ci/integration.cloudbuild.yaml).
[tool-get]:
https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L31
https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L41
[tool-call]:
<https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L79>
https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L229
[mcp-call]:
https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L554
https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L789
[execute-sql]:
<https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L431>
https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L609
[temp-param]:
<https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L297>
https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L454
[temp-param-doc]:
https://googleapis.github.io/genai-toolbox/resources/tools/#template-parameters

View File

@@ -140,7 +140,7 @@ To install Toolbox as a binary:
>
> ```sh
> # see releases page for other versions
> export VERSION=0.23.0
> export VERSION=0.24.0
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
> chmod +x toolbox
> ```
@@ -153,7 +153,7 @@ To install Toolbox as a binary:
>
> ```sh
> # see releases page for other versions
> export VERSION=0.23.0
> export VERSION=0.24.0
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
> chmod +x toolbox
> ```
@@ -166,7 +166,7 @@ To install Toolbox as a binary:
>
> ```sh
> # see releases page for other versions
> export VERSION=0.23.0
> export VERSION=0.24.0
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
> chmod +x toolbox
> ```
@@ -179,7 +179,7 @@ To install Toolbox as a binary:
>
> ```cmd
> :: see releases page for other versions
> set VERSION=0.23.0
> set VERSION=0.24.0
> curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
> ```
>
@@ -191,7 +191,7 @@ To install Toolbox as a binary:
>
> ```powershell
> # see releases page for other versions
> $VERSION = "0.23.0"
> $VERSION = "0.24.0"
> curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
> ```
>
@@ -204,7 +204,7 @@ You can also install Toolbox as a container:
```sh
# see releases page for other versions
export VERSION=0.23.0
export VERSION=0.24.0
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
```
@@ -228,7 +228,7 @@ To install from source, ensure you have the latest version of
[Go installed](https://go.dev/doc/install), and then run the following command:
```sh
go install github.com/googleapis/genai-toolbox@v0.23.0
go install github.com/googleapis/genai-toolbox@v0.24.0
```
<!-- {x-release-please-end} -->

View File

@@ -73,6 +73,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases"
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables"
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch"
@@ -168,6 +169,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqllisttables"
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql"
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan"
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllistactivequeries"
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablefragmentation"
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables"
@@ -233,6 +235,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
_ "github.com/googleapis/genai-toolbox/internal/sources/cassandra"
_ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse"
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
@@ -353,12 +356,12 @@ func NewCommand(opts ...Option) *Command {
flags.StringVarP(&cmd.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.")
flags.IntVarP(&cmd.cfg.Port, "port", "p", 5000, "Port the server will listen on.")
flags.StringVar(&cmd.tools_file, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --prebuilt.")
flags.StringVar(&cmd.tools_file, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.")
// deprecate tools_file
_ = flags.MarkDeprecated("tools_file", "please use --tools-file instead")
flags.StringVar(&cmd.tools_file, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --prebuilt, --tools-files, or --tools-folder.")
flags.StringSliceVar(&cmd.tools_files, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --prebuilt, --tools-file, or --tools-folder.")
flags.StringVar(&cmd.tools_folder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --prebuilt, --tools-file, or --tools-files.")
flags.StringVar(&cmd.tools_file, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.")
flags.StringSliceVar(&cmd.tools_files, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file, or --tools-folder.")
flags.StringVar(&cmd.tools_folder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file, or --tools-files.")
flags.Var(&cmd.cfg.LogLevel, "log-level", "Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'.")
flags.Var(&cmd.cfg.LoggingFormat, "logging-format", "Specify logging format to use. Allowed: 'standard' or 'JSON'.")
flags.BoolVar(&cmd.cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.")
@@ -366,7 +369,7 @@ func NewCommand(opts ...Option) *Command {
flags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
// Fetch prebuilt tools sources to customize the help description
prebuiltHelp := fmt.Sprintf(
"Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. Allowed: '%s'.",
"Use a prebuilt tool configuration by source type. Allowed: '%s'.",
strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"),
)
flags.StringVar(&cmd.prebuiltConfig, "prebuilt", "", prebuiltHelp)
@@ -460,6 +463,9 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
if _, exists := merged.AuthSources[name]; exists {
conflicts = append(conflicts, fmt.Sprintf("authSource '%s' (file #%d)", name, fileIndex+1))
} else {
if merged.AuthSources == nil {
merged.AuthSources = make(server.AuthServiceConfigs)
}
merged.AuthSources[name] = authSource
}
}
@@ -836,16 +842,10 @@ func run(cmd *Command) error {
}
}()
var toolsFile ToolsFile
var allToolsFiles []ToolsFile
// Load Prebuilt Configuration
if cmd.prebuiltConfig != "" {
// Make sure --prebuilt and --tools-file/--tools-files/--tools-folder flags are mutually exclusive
if cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != "" {
errMsg := fmt.Errorf("--prebuilt and --tools-file/--tools-files/--tools-folder flags cannot be used simultaneously")
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
// Use prebuilt tools
buf, err := prebuiltconfigs.Get(cmd.prebuiltConfig)
if err != nil {
cmd.logger.ErrorContext(ctx, err.Error())
@@ -856,72 +856,96 @@ func run(cmd *Command) error {
// Append prebuilt.source to Version string for the User Agent
cmd.cfg.Version += "+prebuilt." + cmd.prebuiltConfig
toolsFile, err = parseToolsFile(ctx, buf)
parsed, err := parseToolsFile(ctx, buf)
if err != nil {
errMsg := fmt.Errorf("unable to parse prebuilt tool configuration: %w", err)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
} else if len(cmd.tools_files) > 0 {
// Make sure --tools-file, --tools-files, and --tools-folder flags are mutually exclusive
if cmd.tools_file != "" || cmd.tools_folder != "" {
errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously")
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
// Use multiple tools files
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files)))
var err error
toolsFile, err = loadAndMergeToolsFiles(ctx, cmd.tools_files)
if err != nil {
cmd.logger.ErrorContext(ctx, err.Error())
return err
}
} else if cmd.tools_folder != "" {
// Make sure --tools-folder and other flags are mutually exclusive
if cmd.tools_file != "" || len(cmd.tools_files) > 0 {
errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously")
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
// Use tools folder
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder))
var err error
toolsFile, err = loadAndMergeToolsFolder(ctx, cmd.tools_folder)
if err != nil {
cmd.logger.ErrorContext(ctx, err.Error())
return err
}
} else {
// Set default value of tools-file flag to tools.yaml
if cmd.tools_file == "" {
cmd.tools_file = "tools.yaml"
}
// Read single tool file contents
buf, err := os.ReadFile(cmd.tools_file)
if err != nil {
errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, err)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
toolsFile, err = parseToolsFile(ctx, buf)
if err != nil {
errMsg := fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
allToolsFiles = append(allToolsFiles, parsed)
}
cmd.cfg.SourceConfigs, cmd.cfg.AuthServiceConfigs, cmd.cfg.ToolConfigs, cmd.cfg.ToolsetConfigs, cmd.cfg.PromptConfigs = toolsFile.Sources, toolsFile.AuthServices, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts
// Determine if Custom Files should be loaded
// Check for explicit custom flags
isCustomConfigured := cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != ""
authSourceConfigs := toolsFile.AuthSources
// Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags)
useDefaultToolsFile := cmd.prebuiltConfig == "" && !isCustomConfigured
if useDefaultToolsFile {
cmd.tools_file = "tools.yaml"
isCustomConfigured = true
}
// Load Custom Configurations
if isCustomConfigured {
// Enforce exclusivity among custom flags (tools-file vs tools-files vs tools-folder)
if (cmd.tools_file != "" && len(cmd.tools_files) > 0) ||
(cmd.tools_file != "" && cmd.tools_folder != "") ||
(len(cmd.tools_files) > 0 && cmd.tools_folder != "") {
errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously")
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
var customTools ToolsFile
var err error
if len(cmd.tools_files) > 0 {
// Use tools-files
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files)))
customTools, err = loadAndMergeToolsFiles(ctx, cmd.tools_files)
} else if cmd.tools_folder != "" {
// Use tools-folder
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder))
customTools, err = loadAndMergeToolsFolder(ctx, cmd.tools_folder)
} else {
// Use single file (tools-file or default `tools.yaml`)
buf, readFileErr := os.ReadFile(cmd.tools_file)
if readFileErr != nil {
errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, readFileErr)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
customTools, err = parseToolsFile(ctx, buf)
if err != nil {
err = fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
}
}
if err != nil {
cmd.logger.ErrorContext(ctx, err.Error())
return err
}
allToolsFiles = append(allToolsFiles, customTools)
}
// Merge Everything
// This will error if custom tools collide with prebuilt tools
finalToolsFile, err := mergeToolsFiles(allToolsFiles...)
if err != nil {
cmd.logger.ErrorContext(ctx, err.Error())
return err
}
cmd.cfg.SourceConfigs = finalToolsFile.Sources
cmd.cfg.AuthServiceConfigs = finalToolsFile.AuthServices
cmd.cfg.ToolConfigs = finalToolsFile.Tools
cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets
cmd.cfg.PromptConfigs = finalToolsFile.Prompts
authSourceConfigs := finalToolsFile.AuthSources
if authSourceConfigs != nil {
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
cmd.cfg.AuthServiceConfigs = authSourceConfigs
for k, v := range authSourceConfigs {
if _, exists := cmd.cfg.AuthServiceConfigs[k]; exists {
errMsg := fmt.Errorf("resource conflict detected: authSource '%s' has the same name as an existing authService. Please rename your authSource", k)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
cmd.cfg.AuthServiceConfigs[k] = v
}
}
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
@@ -972,9 +996,8 @@ func run(cmd *Command) error {
}()
}
watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder)
if !cmd.cfg.DisableReload {
if isCustomConfigured && !cmd.cfg.DisableReload {
watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder)
// start watching the file(s) or folder for changes to trigger dynamic reloading
go watchChanges(ctx, watchDirs, watchedFiles, s)
}

View File

@@ -92,6 +92,21 @@ func invokeCommand(args []string) (*Command, string, error) {
return c, buf.String(), err
}
// invokeCommandWithContext executes the command with a context and returns the captured output.
func invokeCommandWithContext(ctx context.Context, args []string) (*Command, string, error) {
// Capture output using a buffer
buf := new(bytes.Buffer)
c := NewCommand(WithStreams(buf, buf))
c.SetArgs(args)
c.SilenceUsage = true
c.SilenceErrors = true
c.SetContext(ctx)
err := c.Execute()
return c, buf.String(), err
}
func TestVersion(t *testing.T) {
data, err := os.ReadFile("version.txt")
if err != nil {
@@ -1755,11 +1770,6 @@ func TestMutuallyExclusiveFlags(t *testing.T) {
args []string
errString string
}{
{
desc: "--prebuilt and --tools-file",
args: []string{"--prebuilt", "alloydb", "--tools-file", "my.yaml"},
errString: "--prebuilt and --tools-file/--tools-files/--tools-folder flags cannot be used simultaneously",
},
{
desc: "--tools-file and --tools-files",
args: []string{"--tools-file", "my.yaml", "--tools-files", "a.yaml,b.yaml"},
@@ -1902,3 +1912,228 @@ func TestMergeToolsFiles(t *testing.T) {
})
}
}
func TestPrebuiltAndCustomTools(t *testing.T) {
t.Setenv("SQLITE_DATABASE", "test.db")
// Setup custom tools file
customContent := `
tools:
custom_tool:
kind: http
source: my-http
method: GET
path: /
description: "A custom tool for testing"
sources:
my-http:
kind: http
baseUrl: http://example.com
`
customFile := filepath.Join(t.TempDir(), "custom.yaml")
if err := os.WriteFile(customFile, []byte(customContent), 0644); err != nil {
t.Fatal(err)
}
// Tool Conflict File
// SQLite prebuilt has a tool named 'list_tables'
toolConflictContent := `
tools:
list_tables:
kind: http
source: my-http
method: GET
path: /
description: "Conflicting tool"
sources:
my-http:
kind: http
baseUrl: http://example.com
`
toolConflictFile := filepath.Join(t.TempDir(), "tool_conflict.yaml")
if err := os.WriteFile(toolConflictFile, []byte(toolConflictContent), 0644); err != nil {
t.Fatal(err)
}
// Source Conflict File
// SQLite prebuilt has a source named 'sqlite-source'
sourceConflictContent := `
sources:
sqlite-source:
kind: http
baseUrl: http://example.com
tools:
dummy_tool:
kind: http
source: sqlite-source
method: GET
path: /
description: "Dummy"
`
sourceConflictFile := filepath.Join(t.TempDir(), "source_conflict.yaml")
if err := os.WriteFile(sourceConflictFile, []byte(sourceConflictContent), 0644); err != nil {
t.Fatal(err)
}
// Toolset Conflict File
// SQLite prebuilt has a toolset named 'sqlite_database_tools'
toolsetConflictContent := `
sources:
dummy-src:
kind: http
baseUrl: http://example.com
tools:
dummy_tool:
kind: http
source: dummy-src
method: GET
path: /
description: "Dummy"
toolsets:
sqlite_database_tools:
- dummy_tool
`
toolsetConflictFile := filepath.Join(t.TempDir(), "toolset_conflict.yaml")
if err := os.WriteFile(toolsetConflictFile, []byte(toolsetConflictContent), 0644); err != nil {
t.Fatal(err)
}
//Legacy Auth File
authContent := `
authSources:
legacy-auth:
kind: google
clientId: "test-client-id"
`
authFile := filepath.Join(t.TempDir(), "auth.yaml")
if err := os.WriteFile(authFile, []byte(authContent), 0644); err != nil {
t.Fatal(err)
}
testCases := []struct {
desc string
args []string
wantErr bool
errString string
cfgCheck func(server.ServerConfig) error
}{
{
desc: "success mixed",
args: []string{"--prebuilt", "sqlite", "--tools-file", customFile},
wantErr: false,
cfgCheck: func(cfg server.ServerConfig) error {
if _, ok := cfg.ToolConfigs["custom_tool"]; !ok {
return fmt.Errorf("custom tool not found")
}
if _, ok := cfg.ToolConfigs["list_tables"]; !ok {
return fmt.Errorf("prebuilt tool 'list_tables' not found")
}
return nil
},
},
{
desc: "tool conflict error",
args: []string{"--prebuilt", "sqlite", "--tools-file", toolConflictFile},
wantErr: true,
errString: "resource conflicts detected",
},
{
desc: "source conflict error",
args: []string{"--prebuilt", "sqlite", "--tools-file", sourceConflictFile},
wantErr: true,
errString: "resource conflicts detected",
},
{
desc: "toolset conflict error",
args: []string{"--prebuilt", "sqlite", "--tools-file", toolsetConflictFile},
wantErr: true,
errString: "resource conflicts detected",
},
{
desc: "legacy auth additive",
args: []string{"--prebuilt", "sqlite", "--tools-file", authFile},
wantErr: false,
cfgCheck: func(cfg server.ServerConfig) error {
if _, ok := cfg.AuthServiceConfigs["legacy-auth"]; !ok {
return fmt.Errorf("legacy auth source not merged into auth services")
}
return nil
},
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
cmd, output, err := invokeCommandWithContext(ctx, tc.args)
if tc.wantErr {
if err == nil {
t.Fatalf("expected an error but got none")
}
if !strings.Contains(err.Error(), tc.errString) {
t.Errorf("expected error message to contain %q, but got %q", tc.errString, err.Error())
}
} else {
if err != nil && err != context.DeadlineExceeded && err != context.Canceled {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(output, "Server ready to serve!") {
t.Errorf("server did not start successfully (no ready message found). Output:\n%s", output)
}
if tc.cfgCheck != nil {
if err := tc.cfgCheck(cmd.cfg); err != nil {
t.Errorf("config check failed: %v", err)
}
}
}
})
}
}
func TestDefaultToolsFileBehavior(t *testing.T) {
t.Setenv("SQLITE_DATABASE", "test.db")
testCases := []struct {
desc string
args []string
expectRun bool
errString string
}{
{
desc: "no flags (defaults to tools.yaml)",
args: []string{},
expectRun: false,
errString: "tools.yaml", // Expect error because tools.yaml doesn't exist in test env
},
{
desc: "prebuilt only (skips tools.yaml)",
args: []string{"--prebuilt", "sqlite"},
expectRun: true,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
_, output, err := invokeCommandWithContext(ctx, tc.args)
if tc.expectRun {
if err != nil && err != context.DeadlineExceeded && err != context.Canceled {
t.Fatalf("expected server start, got error: %v", err)
}
// Verify it actually started
if !strings.Contains(output, "Server ready to serve!") {
t.Errorf("server did not start successfully (no ready message found). Output:\n%s", output)
}
} else {
if err == nil {
t.Fatalf("expected error reading default file, got nil")
}
if !strings.Contains(err.Error(), tc.errString) {
t.Errorf("expected error message to contain %q, but got %q", tc.errString, err.Error())
}
}
})
}
}

View File

@@ -1 +1 @@
0.23.0
0.24.0

View File

@@ -183,11 +183,11 @@ Protocol (OTLP). If you would like to use a collector, please refer to this
The following flags are used to determine Toolbox's telemetry configuration:
| **flag** | **type** | **description** |
|----------------------------|----------|------------------------------------------------------------------------------------------------------------------|
| `--telemetry-gcp` | bool | Enable exporting directly to Google Cloud Monitoring. Default is `false`. |
| `--telemetry-otlp` | string | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. "<http://127.0.0.1:4318>"). |
| `--telemetry-service-name` | string | Sets the value of the `service.name` resource attribute. Default is `toolbox`. |
| **flag** | **type** | **description** |
|----------------------------|----------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `--telemetry-gcp` | bool | Enable exporting directly to Google Cloud Monitoring. Default is `false`. |
| `--telemetry-otlp` | string | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. "127.0.0.1:4318"). To pass an insecure endpoint here, set environment variable `OTEL_EXPORTER_OTLP_INSECURE=true`. |
| `--telemetry-service-name` | string | Sets the value of the `service.name` resource attribute. Default is `toolbox`. |
In addition to the flags noted above, you can also make additional configuration
for OpenTelemetry via the [General SDK Configuration][sdk-configuration] through
@@ -207,5 +207,5 @@ To enable Google Cloud Exporter:
To enable OTLP Exporter, provide Collector endpoint:
```bash
./toolbox --telemetry-otlp="http://127.0.0.1:4553"
./toolbox --telemetry-otlp="127.0.0.1:4553"
```

View File

@@ -234,7 +234,7 @@
},
"outputs": [],
"source": [
"version = \"0.23.0\" # x-release-please-version\n",
"version = \"0.24.0\" # x-release-please-version\n",
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
"\n",
"# Make the binary executable\n",

View File

@@ -103,7 +103,7 @@ To install Toolbox as a binary on Linux (AMD64):
```sh
# see releases page for other versions
export VERSION=0.23.0
export VERSION=0.24.0
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
chmod +x toolbox
```
@@ -114,7 +114,7 @@ To install Toolbox as a binary on macOS (Apple Silicon):
```sh
# see releases page for other versions
export VERSION=0.23.0
export VERSION=0.24.0
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
chmod +x toolbox
```
@@ -125,7 +125,7 @@ To install Toolbox as a binary on macOS (Intel):
```sh
# see releases page for other versions
export VERSION=0.23.0
export VERSION=0.24.0
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
chmod +x toolbox
```
@@ -136,7 +136,7 @@ To install Toolbox as a binary on Windows (Command Prompt):
```cmd
:: see releases page for other versions
set VERSION=0.23.0
set VERSION=0.24.0
curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
```
@@ -146,7 +146,7 @@ To install Toolbox as a binary on Windows (PowerShell):
```powershell
# see releases page for other versions
$VERSION = "0.23.0"
$VERSION = "0.24.0"
curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
```
@@ -158,7 +158,7 @@ You can also install Toolbox as a container:
```sh
# see releases page for other versions
export VERSION=0.23.0
export VERSION=0.24.0
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
```
@@ -177,7 +177,7 @@ To install from source, ensure you have the latest version of
[Go installed](https://go.dev/doc/install), and then run the following command:
```sh
go install github.com/googleapis/genai-toolbox@v0.23.0
go install github.com/googleapis/genai-toolbox@v0.24.0
```
{{% /tab %}}

View File

@@ -105,7 +105,7 @@ In this section, we will download Toolbox, configure our tools in a
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -13,7 +13,7 @@ In this section, we will download Toolbox, configure our tools in a
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -49,19 +49,19 @@ to expose your developer assistant tools to a Looker instance:
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -45,19 +45,19 @@ instance:
<!-- {x-release-please-start-version} -->
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -43,19 +43,19 @@ expose your developer assistant tools to a MySQL instance:
<!-- {x-release-please-start-version} -->
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -44,19 +44,19 @@ expose your developer assistant tools to a Neo4j instance:
<!-- {x-release-please-start-version} -->
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -56,19 +56,19 @@ Omni](https://cloud.google.com/alloydb/omni/current/docs/overview).
<!-- {x-release-please-start-version} -->
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -43,19 +43,19 @@ to expose your developer assistant tools to a SQLite instance:
<!-- {x-release-please-start-version} -->
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -79,12 +79,16 @@ There are a couple of steps to run and use a Collector.
```
1. Run toolbox with the `--telemetry-otlp` flag. Configure it to send them to
`http://127.0.0.1:4553` (for HTTP) or the Collector's URL.
`127.0.0.1:4553` (for HTTP) or the Collector's URL.
```bash
./toolbox --telemetry-otlp=http://127.0.0.1:4553
./toolbox --telemetry-otlp=127.0.0.1:4553
```
{{< notice tip >}}
To pass an insecure endpoint, set environment variable `OTEL_EXPORTER_OTLP_INSECURE=true`.
{{< /notice >}}
1. Once telemetry datas are collected, you can view them in your telemetry
backend. If you are using GCP exporters, telemetry will be visible in GCP
dashboard at [Metrics Explorer][metrics-explorer] and [Trace

View File

@@ -16,14 +16,14 @@ description: >
| | `--log-level` | Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'. | `info` |
| | `--logging-format` | Specify logging format to use. Allowed: 'standard' or 'JSON'. | `standard` |
| `-p` | `--port` | Port the server will listen on. | `5000` |
| | `--prebuilt` | Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | |
| | `--prebuilt` | Use a prebuilt tool configuration by source type. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | |
| | `--stdio` | Listens via MCP STDIO instead of acting as a remote HTTP server. | |
| | `--telemetry-gcp` | Enable exporting directly to Google Cloud Monitoring. | |
| | `--telemetry-otlp` | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318') | |
| | `--telemetry-service-name` | Sets the value of the service.name resource attribute for telemetry data. | `toolbox` |
| | `--tools-file` | File path specifying the tool configuration. Cannot be used with --prebuilt, --tools-files, or --tools-folder. | |
| | `--tools-files` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --prebuilt, --tools-file, or --tools-folder. | |
| | `--tools-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --prebuilt, --tools-file, or --tools-files. | |
| | `--tools-file` | File path specifying the tool configuration. Cannot be used with --tools-files or --tools-folder. | |
| | `--tools-files` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file or --tools-folder. | |
| | `--tools-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file or --tools-files. | |
| | `--ui` | Launches the Toolbox UI web server. | |
| | `--allowed-origins` | Specifies a list of origins permitted to access this server. | `*` |
| `-v` | `--version` | version for toolbox | |
@@ -46,6 +46,9 @@ description: >
```bash
# Basic server with custom port configuration
./toolbox --tools-file "tools.yaml" --port 8080
# Server with prebuilt + custom tools configurations
./toolbox --tools-file tools.yaml --prebuilt alloydb-postgres
```
### Tool Configuration Sources
@@ -72,8 +75,8 @@ The CLI supports multiple mutually exclusive ways to specify tool configurations
{{< notice tip >}}
The CLI enforces mutual exclusivity between configuration source flags,
preventing simultaneous use of `--prebuilt` with file-based options, and
ensuring only one of `--tools-file`, `--tools-files`, or `--tools-folder` is
preventing simultaneous use of the file-based options ensuring only one of
`--tools-file`, `--tools-files`, or `--tools-folder` is
used at a time.
{{< /notice >}}

View File

@@ -13,6 +13,12 @@ allowing developers to interact with and take action on databases.
See guides, [Connect from your IDE](../how-to/connect-ide/_index.md), for
details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP.
{{< notice tip >}}
You can now use `--prebuilt` along `--tools-file`, `--tools-files`, or
`--tools-folder` to combine prebuilt configs with custom tools.
See [Usage Examples](../reference/cli.md#examples).
{{< /notice >}}
## AlloyDB Postgres
* `--prebuilt` value: `alloydb-postgres`

View File

@@ -0,0 +1,40 @@
---
title: "Gemini Data Analytics"
type: docs
weight: 1
description: >
A "cloud-gemini-data-analytics" source provides a client for the Gemini Data Analytics API.
aliases:
- /resources/sources/cloud-gemini-data-analytics
---
## About
The `cloud-gemini-data-analytics` source provides a client to interact with the [Gemini Data Analytics API](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/reference/rest). This allows tools to send natural language queries to the API.
Authentication can be handled in two ways:
1. **Application Default Credentials (ADC) (Recommended):** By default, the source uses ADC to authenticate with the API. The Toolbox server will fetch the credentials from its running environment (server-side authentication). This is the recommended method.
2. **Client-side OAuth:** If `useClientOAuth` is set to `true`, the source expects the authentication token to be provided by the caller when making a request to the Toolbox server (typically via an HTTP Bearer token). The Toolbox server will then forward this token to the underlying Gemini Data Analytics API calls.
## Example
```yaml
sources:
my-gda-source:
kind: cloud-gemini-data-analytics
projectId: my-project-id
my-oauth-gda-source:
kind: cloud-gemini-data-analytics
projectId: my-project-id
useClientOAuth: true
```
## Reference
| **field** | **type** | **required** | **description** |
| -------------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| kind | string | true | Must be "cloud-gemini-data-analytics". |
| projectId | string | true | The Google Cloud Project ID where the API is enabled. |
| useClientOAuth | boolean | false | If true, the source uses the token provided by the caller (forwarded to the API). Otherwise, it uses server-side Application Default Credentials (ADC). Defaults to `false`. |

View File

@@ -31,6 +31,9 @@ to a database by following these instructions][csql-mysql-quickstart].
- [`mysql-list-active-queries`](../tools/mysql/mysql-list-active-queries.md)
List active queries in Cloud SQL for MySQL.
- [`mysql-get-query-plan`](../tools/mysql/mysql-get-query-plan.md)
Provide information about how MySQL executes a SQL statement (EXPLAIN).
- [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md)
List tables in a Cloud SQL for MySQL database.

View File

@@ -25,6 +25,9 @@ reliability, performance, and ease of use.
- [`mysql-list-active-queries`](../tools/mysql/mysql-list-active-queries.md)
List active queries in MySQL.
- [`mysql-get-query-plan`](../tools/mysql/mysql-get-query-plan.md)
Provide information about how MySQL executes a SQL statement (EXPLAIN).
- [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md)
List tables in a MySQL database.

View File

@@ -18,10 +18,10 @@ DW) database workloads.
## Available Tools
- [`oracle-sql`](../tools/oracle/oracle-sql.md)
Execute pre-defined prepared SQL queries in Oracle.
Execute pre-defined prepared SQL queries in Oracle.
- [`oracle-execute-sql`](../tools/oracle/oracle-execute-sql.md)
Run parameterized SQL queries in Oracle.
Run parameterized SQL queries in Oracle.
## Requirements
@@ -33,6 +33,25 @@ user][oracle-users] to log in to the database with the necessary permissions.
[oracle-users]:
https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/CREATE-USER.html
### Oracle Driver Requirement (Conditional)
The Oracle source offers two connection drivers:
1. **Pure Go Driver (`useOCI: false`, default):** Uses the `go-ora` library.
This driver is simpler and does not require any local Oracle software
installation, but it **lacks support for advanced features** like Oracle
Wallets or Kerberos authentication.
2. **OCI-Based Driver (`useOCI: true`):** Uses the `godror` library, which
provides access to **advanced Oracle features** like Digital Wallet support.
If you set `useOCI: true`, you **must** install the **Oracle Instant Client**
libraries on the machine where this tool runs.
You can download the Instant Client from the official Oracle website: [Oracle
Instant Client
Downloads](https://www.oracle.com/database/technologies/instant-client/downloads.html)
## Connection Methods
You can configure the connection to your Oracle database using one of the
@@ -66,12 +85,15 @@ using a TNS (Transparent Network Substrate) alias.
containing it. This setting will override the `TNS_ADMIN` environment
variable.
## Example
## Examples
This example demonstrates the four connection methods you could choose from:
```yaml
sources:
my-oracle-source:
kind: oracle
# --- Choose one connection method ---
# 1. Host, Port, and Service Name
host: 127.0.0.1
@@ -88,6 +110,43 @@ sources:
user: ${USER_NAME}
password: ${PASSWORD}
# Optional: Set to true to use the OCI-based driver for advanced features (Requires Oracle Instant Client)
```
### Using an Oracle Wallet
Oracle Wallet allows you to store credentails used for database connection. Depending whether you are using an OCI-based driver, the wallet configuration is different.
#### Pure Go Driver (`useOCI: false`) - Oracle Wallet
The `go-ora` driver uses the `walletLocation` field to connect to a database secured with an Oracle Wallet without standard username and password.
```yaml
sources:
pure-go-wallet:
kind: oracle
connectionString: "127.0.0.1:1521/XEPDB1"
user: ${USER_NAME}
password: ${PASSWORD}
# The TNS Alias is often required to connect to a service registered in tnsnames.ora
tnsAlias: "SECURE_DB_ALIAS"
walletLocation: "/path/to/my/wallet/directory"
```
#### OCI-Based Driver (`useOCI: true`) - Oracle Wallet
For the OCI-based driver, wallet authentication is triggered by setting tnsAdmin to the wallet directory and connecting via a tnsAlias.
```yaml
sources:
oci-wallet:
kind: oracle
connectionString: "127.0.0.1:1521/XEPDB1"
user: ${USER_NAME}
password: ${PASSWORD}
tnsAlias: "WALLET_DB_ALIAS"
tnsAdmin: "/opt/oracle/wallet" # Directory containing tnsnames.ora, sqlnet.ora, and wallet files
useOCI: true
```
{{< notice tip >}}
@@ -97,14 +156,15 @@ instead of hardcoding your secrets into the configuration file.
## Reference
| **field** | **type** | **required** | **description** |
|------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "oracle". |
| user | string | true | Name of the Oracle user to connect as (e.g. "my-oracle-user"). |
| password | string | true | Password of the Oracle user (e.g. "my-password"). |
| host | string | false | IP address or hostname to connect to (e.g. "127.0.0.1"). Required if not using `connectionString` or `tnsAlias`. |
| port | integer | false | Port to connect to (e.g. "1521"). Required if not using `connectionString` or `tnsAlias`. |
| serviceName | string | false | The Oracle service name of the database to connect to. Required if not using `connectionString` or `tnsAlias`. |
| connectionString | string | false | A direct connection string (e.g. "hostname:port/servicename"). Use as an alternative to `host`, `port`, and `serviceName`. |
| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. |
| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. |
| **field** | **type** | **required** | **description** |
|------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "oracle". |
| user | string | true | Name of the Oracle user to connect as (e.g. "my-oracle-user"). |
| password | string | true | Password of the Oracle user (e.g. "my-password"). |
| host | string | false | IP address or hostname to connect to (e.g. "127.0.0.1"). Required if not using `connectionString` or `tnsAlias`. |
| port | integer | false | Port to connect to (e.g. "1521"). Required if not using `connectionString` or `tnsAlias`. |
| serviceName | string | false | The Oracle service name of the database to connect to. Required if not using `connectionString` or `tnsAlias`. |
| connectionString | string | false | A direct connection string (e.g. "hostname:port/servicename"). Use as an alternative to `host`, `port`, and `serviceName`. |
| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. |
| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. |
| useOCI | bool | false | If true, uses the OCI-based driver (godror) which supports Oracle Wallet/Kerberos but requires the Oracle Instant Client libraries to be installed. Defaults to false (pure Go driver). |

View File

@@ -0,0 +1,7 @@
---
title: "Gemini Data Analytics"
type: docs
weight: 1
description: >
Tools for Gemini Data Analytics.
---

View File

@@ -0,0 +1,92 @@
---
title: "Gemini Data Analytics QueryData"
type: docs
weight: 1
description: >
A tool to convert natural language queries into SQL statements using the Gemini Data Analytics QueryData API.
aliases:
- /resources/tools/cloud-gemini-data-analytics-query
---
## About
The `cloud-gemini-data-analytics-query` tool allows you to send natural language questions to the Gemini Data Analytics API and receive structured responses containing SQL queries, natural language answers, and explanations. For details on defining data agent context for database data sources, see the official [documentation](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases).
## Example
```yaml
tools:
my-gda-query-tool:
kind: cloud-gemini-data-analytics-query
source: my-gda-source
description: "Use this tool to send natural language queries to the Gemini Data Analytics API and receive SQL, natural language answers, and explanations."
location: ${your_database_location}
context:
datasourceReferences:
cloudSqlReference:
databaseReference:
projectId: "${your_project_id}"
region: "${your_database_instance_region}"
instanceId: "${your_database_instance_id}"
databaseId: "${your_database_name}"
engine: "POSTGRESQL"
agentContextReference:
contextSetId: "${your_context_set_id}" # E.g. projects/${project_id}/locations/${context_set_location}/contextSets/${context_set_id}
generationOptions:
generateQueryResult: true
generateNaturalLanguageAnswer: true
generateExplanation: true
generateDisambiguationQuestion: true
```
### Usage Flow
When using this tool, a `prompt` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
The structure of the response depends on the `generationOptions` configured in your tool definition (e.g., enabling `generateQueryResult` will include the SQL query results).
See [Data Analytics API REST documentation](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1alpha/projects.locations/queryData?rep_location=global) for details.
**Example Input Prompt:**
```text
How many accounts who have region in Prague are eligible for loans? A3 contains the data of region.
```
**Example API Response:**
```json
{
"generatedQuery": "SELECT COUNT(T1.account_id) FROM account AS T1 INNER JOIN loan AS T2 ON T1.account_id = T2.account_id INNER JOIN district AS T3 ON T1.district_id = T3.district_id WHERE T3.A3 = 'Prague'",
"intentExplanation": "I found a template that matches the user's question. The template asks about the number of accounts who have region in a given city and are eligible for loans. The question asks about the number of accounts who have region in Prague and are eligible for loans. The template's parameterized SQL is 'SELECT COUNT(T1.account_id) FROM account AS T1 INNER JOIN loan AS T2 ON T1.account_id = T2.account_id INNER JOIN district AS T3 ON T1.district_id = T3.district_id WHERE T3.A3 = ?'. I will replace the named parameter '?' with 'Prague'.",
"naturalLanguageAnswer": "There are 84 accounts from the Prague region that are eligible for loans.",
"queryResult": {
"columns": [
{
"type": "INT64"
}
],
"rows": [
{
"values": [
{
"value": "84"
}
]
}
],
"totalRowCount": "1"
}
}
```
## Reference
| **field** | **type** | **required** | **description** |
| ----------------- | :------: | :----------: | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| kind | string | true | Must be "cloud-gemini-data-analytics-query". |
| source | string | true | The name of the `cloud-gemini-data-analytics` source to use. |
| description | string | true | A description of the tool's purpose. |
| location | string | true | The Google Cloud location of the target database resource (e.g., "us-central1"). This is used to construct the parent resource name in the API call. |
| context | object | true | The context for the query, including datasource references. See [QueryDataContext](https://github.com/googleapis/googleapis/blob/b32495a713a68dd0dff90cf0b24021debfca048a/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto#L156) for details. |
| generationOptions | object | false | Options for generating the response. See [GenerationOptions](https://github.com/googleapis/googleapis/blob/b32495a713a68dd0dff90cf0b24021debfca048a/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto#L135) for details. |

View File

@@ -0,0 +1,39 @@
---
title: "mysql-get-query-plan"
type: docs
weight: 1
description: >
A "mysql-get-query-plan" tool gets the execution plan for a SQL statement against a MySQL
database.
aliases:
- /resources/tools/mysql-get-query-plan
---
## About
A `mysql-get-query-plan` tool gets the execution plan for a SQL statement against a MySQL
database. It's compatible with any of the following sources:
- [cloud-sql-mysql](../../sources/cloud-sql-mysql.md)
- [mysql](../../sources/mysql.md)
`mysql-get-query-plan` takes one input parameter `sql_statement` and gets the execution plan for the SQL
statement against the `source`.
## Example
```yaml
tools:
get_query_plan_tool:
kind: mysql-get-query-plan
source: my-mysql-instance
description: Use this tool to get the execution plan for a sql statement.
```
## Reference
| **field** | **type** | **required** | **description** |
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "mysql-get-query-plan". |
| source | string | true | Name of the source the SQL should execute on. |
| description | string | true | Description of the tool that is passed to the LLM. |

View File

@@ -771,7 +771,7 @@
},
"outputs": [],
"source": [
"version = \"0.23.0\" # x-release-please-version\n",
"version = \"0.24.0\" # x-release-please-version\n",
"! curl -L -o /content/toolbox https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
"\n",
"# Make the binary executable\n",

View File

@@ -123,7 +123,7 @@ In this section, we will download and install the Toolbox binary.
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
export VERSION="0.23.0"
export VERSION="0.24.0"
curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -220,7 +220,7 @@
},
"outputs": [],
"source": [
"version = \"0.23.0\" # x-release-please-version\n",
"version = \"0.24.0\" # x-release-please-version\n",
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
"\n",
"# Make the binary executable\n",

View File

@@ -179,7 +179,7 @@ to use BigQuery, and then run the Toolbox server.
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -98,7 +98,7 @@ In this section, we will download Toolbox, configure our tools in a
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server.
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -48,7 +48,7 @@ In this section, we will download Toolbox and run the Toolbox server.
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server.
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -1,6 +1,6 @@
{
"name": "mcp-toolbox-for-databases",
"version": "0.23.0",
"version": "0.24.0",
"description": "MCP Toolbox for Databases is an open-source MCP server for more than 30 different datasources.",
"contextFileName": "MCP-TOOLBOX-EXTENSION.md"
}

12
go.mod
View File

@@ -12,7 +12,7 @@ require (
cloud.google.com/go/dataplex v1.28.0
cloud.google.com/go/dataproc/v2 v2.15.0
cloud.google.com/go/firestore v1.20.0
cloud.google.com/go/geminidataanalytics v0.2.1
cloud.google.com/go/geminidataanalytics v0.3.0
cloud.google.com/go/longrunning v0.7.0
cloud.google.com/go/spanner v1.86.1
github.com/ClickHouse/clickhouse-go/v2 v2.40.3
@@ -22,7 +22,7 @@ require (
github.com/cenkalti/backoff/v5 v5.0.3
github.com/couchbase/gocb/v2 v2.11.1
github.com/couchbase/tools-common/http v1.0.9
github.com/elastic/elastic-transport-go/v8 v8.7.0
github.com/elastic/elastic-transport-go/v8 v8.8.0
github.com/elastic/go-elasticsearch/v9 v9.2.0
github.com/fsnotify/fsnotify v1.9.0
github.com/go-chi/chi/v5 v5.2.3
@@ -33,6 +33,7 @@ require (
github.com/go-playground/validator/v10 v10.28.0
github.com/go-sql-driver/mysql v1.9.3
github.com/goccy/go-yaml v1.18.0
github.com/godror/godror v0.49.6
github.com/google/go-cmp v0.7.0
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.6
@@ -41,7 +42,7 @@ require (
github.com/microsoft/go-mssqldb v1.9.3
github.com/nakagami/firebirdsql v0.9.15
github.com/neo4j/neo4j-go-driver/v5 v5.28.4
github.com/redis/go-redis/v9 v9.16.0
github.com/redis/go-redis/v9 v9.17.2
github.com/sijms/go-ora/v2 v2.9.0
github.com/spf13/cobra v1.10.1
github.com/thlib/go-timezone-local v0.0.7
@@ -91,6 +92,7 @@ require (
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0 // indirect
github.com/PuerkitoBio/goquery v1.10.3 // indirect
github.com/VictoriaMetrics/easyproto v0.1.4 // indirect
github.com/ajg/form v1.5.1 // indirect
github.com/apache/arrow/go/v15 v15.0.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -107,11 +109,13 @@ require (
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/gabriel-vasile/mimetype v1.4.10 // indirect
github.com/go-jose/go-jose/v4 v4.1.2 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/godror/knownpb v0.3.0 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
@@ -181,7 +185,7 @@ require (
golang.org/x/time v0.14.0 // indirect
golang.org/x/tools v0.38.0 // indirect
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect
google.golang.org/grpc v1.76.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect

30
go.sum
View File

@@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
cloud.google.com/go/geminidataanalytics v0.2.1 h1:gtG/9VlUJpL67yukFen/twkAEHliYvW7610Rlnn5rpQ=
cloud.google.com/go/geminidataanalytics v0.2.1/go.mod h1:gIsj/ELDCzVbw24185zwjXgbzYiqdGe7TSSK2HrdtA0=
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=
@@ -683,6 +683,10 @@ github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo=
github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y=
github.com/UNO-SOFT/zlog v0.8.1 h1:TEFkGJHtUfTRgMkLZiAjLSHALjwSBdw6/zByMC5GJt4=
github.com/UNO-SOFT/zlog v0.8.1/go.mod h1:yqFOjn3OhvJ4j7ArJqQNA+9V+u6t9zSAyIZdWdMweWc=
github.com/VictoriaMetrics/easyproto v0.1.4 h1:r8cNvo8o6sR4QShBXQd1bKw/VVLSQma/V2KhTBPf+Sc=
github.com/VictoriaMetrics/easyproto v0.1.4/go.mod h1:QlGlzaJnDfFd8Lk6Ci/fuLxfTo3/GThPs2KH23mv710=
github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:3YVZUqkoev4mL+aCwVOSWV4M7pN+NURHL38Z2zq5JKA=
github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:ymXt5bw5uSNu4jveerFxE0vNYxF8ncqbptntMaFMg3k=
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
@@ -818,8 +822,8 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/elastic/elastic-transport-go/v8 v8.7.0 h1:OgTneVuXP2uip4BA658Xi6Hfw+PeIOod2rY3GVMGoVE=
github.com/elastic/elastic-transport-go/v8 v8.7.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk=
github.com/elastic/elastic-transport-go/v8 v8.8.0 h1:7k1Ua+qluFr6p1jfJjGDl97ssJS/P7cHNInzfxgBQAo=
github.com/elastic/elastic-transport-go/v8 v8.8.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk=
github.com/elastic/go-elasticsearch/v9 v9.2.0 h1:COeL/g20+ixnUbffe4Wfbu88emrHjAq/LhVfmrjqRQs=
github.com/elastic/go-elasticsearch/v9 v9.2.0/go.mod h1:2PB5YQPpY5tWbF65MRqzEXA31PZOdXCkloQSOZtU14I=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -884,6 +888,8 @@ github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vb
github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk=
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -909,6 +915,10 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godror/godror v0.49.6 h1:ts4ZGw8uLJ42e1D7aXmVuSrld0/lzUzmIUjuUuQOgGM=
github.com/godror/godror v0.49.6/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8=
github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw=
github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
@@ -1172,6 +1182,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/neo4j/neo4j-go-driver/v5 v5.28.4 h1:7toxehVcYkZbyxV4W3Ib9VcnyRBQPucF+VwNNmtSXi4=
github.com/neo4j/neo4j-go-driver/v5 v5.28.4/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k=
github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc=
github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68=
github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8=
github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
@@ -1210,8 +1222,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w=
github.com/redis/go-redis/v9 v9.16.0 h1:OotgqgLSRCmzfqChbQyG1PHC3tLNR89DG4jdOERSEP4=
github.com/redis/go-redis/v9 v9.16.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
@@ -1671,6 +1683,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -1990,8 +2004,8 @@ google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOl
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU=
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8 h1:a12a2/BiVRxRWIqBbfqoSK6tgq8cyUgMnEI81QlPge0=
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8/go.mod h1:1Ic78BnpzY8OaTCmzxJDP4qC9INZPbGZl+54RKjtyeI=
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f h1:OiFuztEyBivVKDvguQJYWq1yDcfAHIID/FVrPR4oiI0=
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f/go.mod h1:kprOiu9Tr0JYyD6DORrc4Hfyk3RFXqkQ3ctHEum3ZbM=
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba h1:B14OtaXuMaCQsl2deSvNkyPKIzq3BjfxQp8d00QyWx4=
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba/go.mod h1:G5IanEx8/PgI9w6CFcYQf7jMtHQhZruvfM1i3qOqk5U=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 h1:tRPGkdGHuewF4UisLzzHHr1spKw92qLM98nIzxbC0wY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=

View File

@@ -32,16 +32,9 @@ tools:
source: cloud-sql-mysql-source
description: Lists top N (default 10) ongoing queries from processlist and innodb_trx, ordered by execution time in descending order. Returns detailed information of those queries in json format, including process id, query, transaction duration, transaction wait duration, process time, transaction state, process state, username with host, transaction rows locked, transaction rows modified, and db schema.
get_query_plan:
kind: mysql-sql
kind: mysql-get-query-plan
source: cloud-sql-mysql-source
description: "Provide information about how MySQL executes a SQL statement. Common use cases include: 1) analyze query plan to improve its performance, and 2) determine effectiveness of existing indexes and evalueate new ones."
statement: |
EXPLAIN FORMAT=JSON {{.sql_statement}};
templateParameters:
- name: sql_statement
type: string
description: "the SQL statement to explain"
required: true
list_tables:
kind: mysql-list-tables
source: cloud-sql-mysql-source

View File

@@ -36,16 +36,9 @@ tools:
source: mysql-source
description: Lists top N (default 10) ongoing queries from processlist and innodb_trx, ordered by execution time in descending order. Returns detailed information of those queries in json format, including process id, query, transaction duration, transaction wait duration, process time, transaction state, process state, username with host, transaction rows locked, transaction rows modified, and db schema.
get_query_plan:
kind: mysql-sql
kind: mysql-get-query-plan
source: mysql-source
description: "Provide information about how MySQL executes a SQL statement. Common use cases include: 1) analyze query plan to improve its performance, and 2) determine effectiveness of existing indexes and evalueate new ones."
statement: |
EXPLAIN FORMAT=JSON {{.sql_statement}};
templateParameters:
- name: sql_statement
type: string
description: "the SQL statement to explain"
required: true
list_tables:
kind: mysql-list-tables
source: mysql-source

View File

@@ -172,7 +172,14 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(s.ResourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
s.logger.DebugContext(ctx, errMsg.Error())
_ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound))
return
}
if clientAuth {
if accessToken == "" {
err = fmt.Errorf("tool requires client authorization but access token is missing from the request header")
s.logger.DebugContext(ctx, err.Error())
@@ -255,7 +262,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
}
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
if tool.RequiresClientAuthorization(s.ResourceMgr) {
if clientAuth {
// Propagate the original 401/403 error.
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
_ = render.Render(w, r, newErrResponse(err, statusCode))

View File

@@ -77,9 +77,9 @@ func (t MockTool) Authorized(verifiedAuthServices []string) bool {
return !t.unauthorized
}
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) bool {
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
// defaulted to false
return t.requiresClientAuthrorization
return t.requiresClientAuthrorization, nil
}
func (t MockTool) McpManifest() tools.McpManifest {
@@ -119,8 +119,8 @@ func (t MockTool) McpManifest() tools.McpManifest {
return mcpManifest
}
func (t MockTool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
return "Authorization", nil
}
// MockPrompt is used to mock prompts in tests

View File

@@ -205,10 +205,13 @@ func (s *stdioSession) readLine(ctx context.Context) (string, error) {
}
// write writes to stdout with response to client
func (s *stdioSession) write(ctx context.Context, response any) error {
res, _ := json.Marshal(response)
func (s *stdioSession) write(_ context.Context, response any) error {
res, err := json.Marshal(response)
if err != nil {
return fmt.Errorf("failed to marshal response to JSON: %w", err)
}
_, err := fmt.Fprintf(s.writer, "%s\n", res)
_, err = fmt.Fprintf(s.writer, "%s\n", res)
return err
}

View File

@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Get access token
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(resourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization(resourceMgr) {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}

View File

@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Get access token
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(resourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization(resourceMgr) {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}

View File

@@ -101,10 +101,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Get access token
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(resourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
@@ -176,7 +186,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization(resourceMgr) {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}

View File

@@ -30,26 +30,6 @@ import (
const SourceKind string = "alloydb-admin"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
@@ -87,10 +67,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
@@ -99,10 +76,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
@@ -136,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetDefaultProject() string {
return s.DefaultProject
}
func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbrestapi.Service, error) {
if s.UseClientOAuth {
token := &oauth2.Token{AccessToken: accessToken}

View File

@@ -0,0 +1,133 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda
import (
"context"
"fmt"
"net/http"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const SourceKind string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ProjectID string `yaml:"projectId" validate:"required"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize initializes a Gemini Data Analytics Source instance.
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
ua, err := util.UserAgentFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
}
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
s := &Source{
Config: r,
Client: client,
BaseURL: Endpoint,
userAgent: ua,
}
return s, nil
}
var _ sources.Source = &Source{}
type Source struct {
Config
Client *http.Client
BaseURL string
userAgent string
}
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetProjectID() string {
return s.ProjectID
}
func (s *Source) GetBaseURL() string {
return s.BaseURL
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: accessToken}
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
return baseClient, nil
}
return s.Client, nil
}
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}

View File

@@ -0,0 +1,213 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda_test
import (
"context"
"os"
"path/filepath"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/testutils"
"go.opentelemetry.io/otel/trace/noop"
)
func TestParseFromYamlCloudGDA(t *testing.T) {
t.Parallel()
tcs := []struct {
desc string
in string
want server.SourceConfigs
}{
{
desc: "basic example",
in: `
sources:
my-gda-instance:
kind: cloud-gemini-data-analytics
projectId: test-project-id
`,
want: map[string]sources.SourceConfig{
"my-gda-instance": cloudgda.Config{
Name: "my-gda-instance",
Kind: cloudgda.SourceKind,
ProjectID: "test-project-id",
UseClientOAuth: false,
},
},
},
{
desc: "use client auth example",
in: `
sources:
my-gda-instance:
kind: cloud-gemini-data-analytics
projectId: another-project
useClientOAuth: true
`,
want: map[string]sources.SourceConfig{
"my-gda-instance": cloudgda.Config{
Name: "my-gda-instance",
Kind: cloudgda.SourceKind,
ProjectID: "another-project",
UseClientOAuth: true,
},
},
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
}
func TestFailParseFromYaml(t *testing.T) {
t.Parallel()
tcs := []struct {
desc string
in string
err string
}{
{
desc: "missing projectId",
in: `
sources:
my-gda-instance:
kind: cloud-gemini-data-analytics
`,
err: "unable to parse source \"my-gda-instance\" as \"cloud-gemini-data-analytics\": Key: 'Config.ProjectID' Error:Field validation for 'ProjectID' failed on the 'required' tag",
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}
errStr := err.Error()
if errStr != tc.err {
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
}
})
}
}
func TestInitialize(t *testing.T) {
// Create a dummy credentials file for testing ADC
credFile := filepath.Join(t.TempDir(), "application_default_credentials.json")
dummyCreds := `{
"client_id": "foo",
"client_secret": "bar",
"refresh_token": "baz",
"type": "authorized_user"
}`
if err := os.WriteFile(credFile, []byte(dummyCreds), 0644); err != nil {
t.Fatalf("failed to write dummy credentials file: %v", err)
}
t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", credFile)
// Use ContextWithUserAgent to avoid "unable to retrieve user agent" error
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
tracer := noop.NewTracerProvider().Tracer("test")
tcs := []struct {
desc string
cfg cloudgda.Config
wantClientOAuth bool
}{
{
desc: "initialize with ADC",
cfg: cloudgda.Config{Name: "test-gda", Kind: cloudgda.SourceKind, ProjectID: "test-proj"},
wantClientOAuth: false,
},
{
desc: "initialize with client OAuth",
cfg: cloudgda.Config{Name: "test-gda-oauth", Kind: cloudgda.SourceKind, ProjectID: "test-proj", UseClientOAuth: true},
wantClientOAuth: true,
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
src, err := tc.cfg.Initialize(ctx, tracer)
if err != nil {
t.Fatalf("failed to initialize source: %v", err)
}
gdaSrc, ok := src.(*cloudgda.Source)
if !ok {
t.Fatalf("expected *cloudgda.Source, got %T", src)
}
// Check that the client is non-nil
if gdaSrc.Client == nil && !tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for ADC, got nil")
}
// When client OAuth is true, the source's client should be initialized with a base HTTP client
// that includes the user agent round tripper, but not the OAuth token. The token-aware
// client is created by GetClient.
if gdaSrc.Client == nil && tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
}
// Test UseClientAuthorization method
if gdaSrc.UseClientAuthorization() != tc.wantClientOAuth {
t.Errorf("UseClientAuthorization mismatch: want %t, got %t", tc.wantClientOAuth, gdaSrc.UseClientAuthorization())
}
// Test GetClient with accessToken for client OAuth scenarios
if tc.wantClientOAuth {
client, err := gdaSrc.GetClient(ctx, "dummy-token")
if err != nil {
t.Fatalf("GetClient with token failed: %v", err)
}
if client == nil {
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
}
// Ensure passing empty token with UseClientOAuth enabled returns error
_, err = gdaSrc.GetClient(ctx, "")
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
}
}
})
}
}

View File

@@ -29,26 +29,6 @@ import (
const SourceKind string = "cloud-monitoring"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
@@ -86,10 +66,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
@@ -98,18 +75,15 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
s := &Source{
Config: r,
BaseURL: "https://monitoring.googleapis.com",
Client: client,
UserAgent: ua,
baseURL: "https://monitoring.googleapis.com",
client: client,
userAgent: ua,
}
return s, nil
}
@@ -118,9 +92,9 @@ var _ sources.Source = &Source{}
type Source struct {
Config
BaseURL string `yaml:"baseUrl"`
Client *http.Client
UserAgent string
baseURL string
client *http.Client
userAgent string
}
func (s *Source) SourceKind() string {
@@ -131,6 +105,18 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) BaseURL() string {
return s.baseURL
}
func (s *Source) Client() *http.Client {
return s.client
}
func (s *Source) UserAgent() string {
return s.userAgent
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {
@@ -139,7 +125,7 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
token := &oauth2.Token{AccessToken: accessToken}
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
}
return s.Client, nil
return s.client, nil
}
func (s *Source) UseClientAuthorization() bool {

View File

@@ -30,26 +30,6 @@ import (
const SourceKind string = "cloud-sql-admin"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
@@ -88,10 +68,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
@@ -100,10 +77,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
@@ -136,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetDefaultProject() string {
return s.DefaultProject
}
func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) {
if s.UseClientOAuth {
token := &oauth2.Token{AccessToken: accessToken}

View File

@@ -107,7 +107,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
s := &Source{
Config: r,
Client: &client,
client: &client,
}
return s, nil
@@ -117,7 +117,7 @@ var _ sources.Source = &Source{}
type Source struct {
Config
Client *http.Client
client *http.Client
}
func (s *Source) SourceKind() string {
@@ -127,3 +127,19 @@ func (s *Source) SourceKind() string {
func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) HttpDefaultHeaders() map[string]string {
return s.DefaultHeaders
}
func (s *Source) HttpBaseURL() string {
return s.BaseURL
}
func (s *Source) HttpQueryParams() map[string]string {
return s.QueryParams
}
func (s *Source) Client() *http.Client {
return s.client
}

View File

@@ -160,10 +160,6 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetApiSettings() *rtl.ApiSettings {
return s.ApiSettings
}
func (s *Source) UseClientAuthorization() bool {
return strings.ToLower(s.UseClientOAuth) != "false"
}
@@ -188,6 +184,30 @@ func (s *Source) GoogleCloudTokenSourceWithScope(ctx context.Context, scope stri
return google.DefaultTokenSource(ctx, scope)
}
func (s *Source) LookerClient() *v4.LookerSDK {
return s.Client
}
func (s *Source) LookerApiSettings() *rtl.ApiSettings {
return s.ApiSettings
}
func (s *Source) LookerShowHiddenFields() bool {
return s.ShowHiddenFields
}
func (s *Source) LookerShowHiddenModels() bool {
return s.ShowHiddenModels
}
func (s *Source) LookerShowHiddenExplores() bool {
return s.ShowHiddenExplores
}
func (s *Source) LookerSessionLength() int64 {
return s.SessionLength
}
func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) {
cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...)
if err != nil {

View File

@@ -9,9 +9,11 @@ import (
"strings"
"github.com/goccy/go-yaml"
_ "github.com/godror/godror" // OCI driver
_ "github.com/sijms/go-ora/v2" // Pure Go driver
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
_ "github.com/sijms/go-ora/v2"
"go.opentelemetry.io/otel/trace"
)
@@ -32,7 +34,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
return nil, err
}
// Validate that we have one of: tns_alias, connection_string, or host+service_name
// Validate that we have one of: tnsAlias, connectionString, or host+service_name
if err := actual.validate(); err != nil {
return nil, fmt.Errorf("invalid Oracle configuration: %w", err)
}
@@ -43,21 +45,24 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ConnectionString string `yaml:"connectionString,omitempty"` // Direct connection string (hostname[:port]/servicename)
TnsAlias string `yaml:"tnsAlias,omitempty"` // TNS alias from tnsnames.ora
Host string `yaml:"host,omitempty"` // Optional when using connectionString/tnsAlias
Port int `yaml:"port,omitempty"` // Explicit port support
ServiceName string `yaml:"serviceName,omitempty"` // Optional when using connectionString/tnsAlias
ConnectionString string `yaml:"connectionString,omitempty"`
TnsAlias string `yaml:"tnsAlias,omitempty"`
TnsAdmin string `yaml:"tnsAdmin,omitempty"`
Host string `yaml:"host,omitempty"`
Port int `yaml:"port,omitempty"`
ServiceName string `yaml:"serviceName,omitempty"`
User string `yaml:"user" validate:"required"`
Password string `yaml:"password" validate:"required"`
TnsAdmin string `yaml:"tnsAdmin,omitempty"` // Optional: override TNS_ADMIN environment variable
UseOCI bool `yaml:"useOCI,omitempty"`
WalletLocation string `yaml:"walletLocation,omitempty"`
}
// validate ensures we have one of: tns_alias, connection_string, or host+service_name
func (c Config) validate() error {
hasTnsAdmin := strings.TrimSpace(c.TnsAdmin) != ""
hasTnsAlias := strings.TrimSpace(c.TnsAlias) != ""
hasConnStr := strings.TrimSpace(c.ConnectionString) != ""
hasHostService := strings.TrimSpace(c.Host) != "" && strings.TrimSpace(c.ServiceName) != ""
hasWallet := strings.TrimSpace(c.WalletLocation) != ""
connectionMethods := 0
if hasTnsAlias {
@@ -78,6 +83,14 @@ func (c Config) validate() error {
return fmt.Errorf("provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'")
}
if hasTnsAdmin && !c.UseOCI {
return fmt.Errorf("`tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead")
}
if hasWallet && c.UseOCI {
return fmt.Errorf("when using an OCI driver, use `tnsAdmin` to specify credentials file location instead")
}
return nil
}
@@ -132,7 +145,8 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi
panic(err)
}
// Set TNS_ADMIN environment variable if specified in config.
hasWallet := strings.TrimSpace(config.WalletLocation) != ""
if config.TnsAdmin != "" {
originalTnsAdmin := os.Getenv("TNS_ADMIN")
os.Setenv("TNS_ADMIN", config.TnsAdmin)
@@ -147,28 +161,49 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi
}()
}
var serverString string
var connectStringBase string
if config.TnsAlias != "" {
// Use TNS alias
serverString = strings.TrimSpace(config.TnsAlias)
connectStringBase = strings.TrimSpace(config.TnsAlias)
} else if config.ConnectionString != "" {
// Use provided connection string directly (hostname[:port]/servicename format)
serverString = strings.TrimSpace(config.ConnectionString)
connectStringBase = strings.TrimSpace(config.ConnectionString)
} else {
// Build connection string from host and service_name
if config.Port > 0 {
serverString = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName)
connectStringBase = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName)
} else {
serverString = fmt.Sprintf("%s/%s", config.Host, config.ServiceName)
connectStringBase = fmt.Sprintf("%s/%s", config.Host, config.ServiceName)
}
}
connStr := fmt.Sprintf("oracle://%s:%s@%s",
config.User, config.Password, serverString)
var driverName string
var finalConnStr string
db, err := sql.Open("oracle", connStr)
if config.UseOCI {
// Use godror driver (requires OCI)
driverName = "godror"
finalConnStr = fmt.Sprintf(`user="%s" password="%s" connectString="%s"`,
config.User, config.Password, connectStringBase)
logger.DebugContext(ctx, fmt.Sprintf("Using godror driver (OCI-based) with connectString: %s\n", connectStringBase))
} else {
// Use go-ora driver (pure Go)
driverName = "oracle"
user := config.User
password := config.Password
if hasWallet {
finalConnStr = fmt.Sprintf("oracle://%s:%s@%s?ssl=true&wallet=%s",
user, password, connectStringBase, config.WalletLocation)
} else {
// Standard go-ora connection
finalConnStr = fmt.Sprintf("oracle://%s:%s@%s",
config.User, config.Password, connectStringBase)
logger.DebugContext(ctx, fmt.Sprintf("Using go-ora driver (pure-Go) with serverString: %s\n", connectStringBase))
}
}
db, err := sql.Open(driverName, finalConnStr)
if err != nil {
return nil, fmt.Errorf("unable to open Oracle connection: %w", err)
return nil, fmt.Errorf("unable to open Oracle connection with driver %s: %w", driverName, err)
}
return db, nil

View File

@@ -0,0 +1,200 @@
// Copyright © 2025, Oracle and/or its affiliates.
package oracle_test
import (
"strings"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
func TestParseFromYamlOracle(t *testing.T) {
tcs := []struct {
desc string
in string
want server.SourceConfigs
}{
{
desc: "connection string and useOCI=true",
in: `
sources:
my-oracle-cs:
kind: oracle
connectionString: "my-host:1521/XEPDB1"
user: my_user
password: my_pass
useOCI: true
`,
want: server.SourceConfigs{
"my-oracle-cs": oracle.Config{
Name: "my-oracle-cs",
Kind: oracle.SourceKind,
ConnectionString: "my-host:1521/XEPDB1",
User: "my_user",
Password: "my_pass",
UseOCI: true,
},
},
},
{
desc: "host/port/serviceName and default useOCI=false",
in: `
sources:
my-oracle-host:
kind: oracle
host: my-host
port: 1521
serviceName: ORCLPDB
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
"my-oracle-host": oracle.Config{
Name: "my-oracle-host",
Kind: oracle.SourceKind,
Host: "my-host",
Port: 1521,
ServiceName: "ORCLPDB",
User: "my_user",
Password: "my_pass",
UseOCI: false,
},
},
},
{
desc: "tnsAlias and TnsAdmin specified with explicit useOCI=true",
in: `
sources:
my-oracle-tns-oci:
kind: oracle
tnsAlias: FINANCE_DB
tnsAdmin: /opt/oracle/network/admin
user: my_user
password: my_pass
useOCI: true
`,
want: server.SourceConfigs{
"my-oracle-tns-oci": oracle.Config{
Name: "my-oracle-tns-oci",
Kind: oracle.SourceKind,
TnsAlias: "FINANCE_DB",
TnsAdmin: "/opt/oracle/network/admin",
User: "my_user",
Password: "my_pass",
UseOCI: true,
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse:\nwant: %v\ngot: %v\ndiff: %s", tc.want, got.Sources, cmp.Diff(tc.want, got.Sources))
}
})
}
}
func TestFailParseFromYamlOracle(t *testing.T) {
tcs := []struct {
desc string
in string
err string
}{
{
desc: "extra field",
in: `
sources:
my-oracle-instance:
kind: oracle
host: my-host
serviceName: ORCL
user: my_user
password: my_pass
extraField: value
`,
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": [1:1] unknown field \"extraField\"\n> 1 | extraField: value\n ^\n 2 | host: my-host\n 3 | kind: oracle\n 4 | password: my_pass\n 5 | ",
},
{
desc: "missing required password field",
in: `
sources:
my-oracle-instance:
kind: oracle
host: my-host
serviceName: ORCL
user: my_user
`,
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
},
{
desc: "missing connection method fields (validate fails)",
in: `
sources:
my-oracle-instance:
kind: oracle
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'",
},
{
desc: "multiple connection methods provided (validate fails)",
in: `
sources:
my-oracle-instance:
kind: oracle
host: my-host
serviceName: ORCL
connectionString: "my-host:1521/XEPDB1"
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'",
},
{
desc: "fail on tnsAdmin with useOCI=false",
in: `
sources:
my-oracle-fail:
kind: oracle
tnsAlias: FINANCE_DB
tnsAdmin: /opt/oracle/network/admin
user: my_user
password: my_pass
useOCI: false
`,
err: "unable to parse source \"my-oracle-fail\" as \"oracle\": invalid Oracle configuration: `tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}
errStr := strings.ReplaceAll(err.Error(), "\r", "")
if errStr != tc.err {
t.Fatalf("unexpected error:\ngot:\n%q\nwant:\n%q\n", errStr, tc.err)
}
})
}
}

View File

@@ -96,6 +96,14 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetProject() string {
return s.Project
}
func (s *Source) GetLocation() string {
return s.Location
}
func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient {
return s.Client
}

View File

@@ -46,6 +46,11 @@ func ContextWithNewLogger() (context.Context, error) {
return util.WithLogger(ctx, logger), nil
}
// ContextWithUserAgent creates a new context with a specified user agent string.
func ContextWithUserAgent(ctx context.Context, userAgent string) context.Context {
return util.WithUserAgent(ctx, userAgent)
}
// WaitForString waits until the server logs a single line that matches the provided regex.
// returns the output of whatever the server sent so far.
func WaitForString(ctx context.Context, re *regexp.Regexp, pr io.ReadCloser) (string, error) {

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the create-cluster tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -97,7 +102,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -107,7 +111,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-cluster tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
@@ -120,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
if !ok || project == "" {
@@ -151,7 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -198,10 +206,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the create-instance tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -108,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-instance tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
@@ -121,6 +124,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
if !ok || project == "" {
@@ -147,7 +155,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -208,10 +216,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the create-user tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -108,9 +112,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-user tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -121,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
if !ok || project == "" {
@@ -147,7 +154,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -208,10 +215,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-get-cluster"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the get-cluster tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -104,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the get-cluster tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters
manifest tools.Manifest
@@ -117,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -132,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -167,10 +176,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-get-instance"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the get-instance tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the get-instance tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters
AllParams parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'instance' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-get-user"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the get-user tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the get-user tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters
AllParams parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-list-clusters"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the list-clusters tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -93,7 +99,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -103,9 +108,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the list-clusters tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -116,6 +119,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -127,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -162,10 +170,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-list-instances"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the list-instances tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the list-instances tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-list-users"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the list-users tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the list-users tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -25,9 +25,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-wait-for-operation"
@@ -89,6 +89,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Config defines the configuration for the wait-for-operation tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -119,12 +125,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -180,7 +186,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -194,19 +199,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the wait-for-operation tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
Client *http.Client
manifest tools.Manifest
mcpManifest tools.McpManifest
// Polling configuration
Delay time.Duration
MaxDelay time.Duration
Multiplier float64
MaxRetries int
Client *http.Client
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -215,6 +217,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -230,7 +237,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("missing 'operation' parameter")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -363,10 +370,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"github.com/jackc/pgx/v5/pgxpool"
@@ -47,11 +46,6 @@ type compatibleSource interface {
PostgresPool() *pgxpool.Pool
}
// validate compatible sources are still compatible
var _ compatibleSource = &alloydbpg.Source{}
var compatibleSources = [...]string{alloydbpg.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
numParams := len(cfg.NLConfigParameters)
quotedNameParts := make([]string, 0, numParams)
placeholderParts := make([]string, 0, numParams)
@@ -126,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Config: cfg,
Parameters: cfg.NLConfigParameters,
Statement: stmt,
Pool: s.PostgresPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -139,9 +120,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Pool *pgxpool.Pool
Parameters parameters.Parameters `yaml:"parameters"`
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
@@ -152,6 +131,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
pool := source.PostgresPool()
sliceParams := params.AsSlice()
allParamValues := make([]any, len(sliceParams)+1)
allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question
@@ -160,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
allParamValues[i+2] = fmt.Sprintf("%s", param)
}
results, err := t.Pool.Query(ctx, t.Statement, allParamValues...)
results, err := pool.Query(ctx, t.Statement, allParamValues...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues)
}
@@ -203,10 +188,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -57,11 +57,6 @@ type compatibleSource interface {
BigQuerySession() bigqueryds.BigQuerySessionProvider
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
allowedDatasets := s.BigQueryAllowedDatasets()
@@ -136,17 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
SessionProvider: s.BigQuerySession(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -156,17 +144,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
SessionProvider bigqueryds.BigQuerySessionProvider
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -175,23 +155,27 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke runs the contribution analysis.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
inputData, ok := paramsMap["input_data"].(string)
if !ok {
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
}
bqClient := t.Client
restService := t.RestService
var err error
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -229,9 +213,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
var inputDataSource string
trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
var connProps []*bigqueryapi.ConnectionProperty
session, err := t.SessionProvider(ctx)
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
{Key: "session_id", Value: session.ID},
}
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps)
if err != nil {
return nil, fmt.Errorf("query validation failed: %w", err)
}
@@ -252,7 +236,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
queryStats := dryRunJob.Statistics.Query
if queryStats != nil {
for _, tableRef := range queryStats.ReferencedTables {
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
}
}
@@ -262,18 +246,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
inputDataSource = fmt.Sprintf("(%s)", inputData)
} else {
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
parts := strings.Split(inputData, ".")
var projectID, datasetID string
switch len(parts) {
case 3: // project.dataset.table
projectID, datasetID = parts[0], parts[1]
case 2: // dataset.table
projectID, datasetID = t.Client.Project(), parts[0]
projectID, datasetID = source.BigQueryClient().Project(), parts[0]
default:
return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData)
}
if !t.IsDatasetAllowed(projectID, datasetID) {
if !source.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData)
}
}
@@ -292,7 +276,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// Get session from provider if in protected mode.
// Otherwise, a new session will be created by the first query.
session, err := t.SessionProvider(ctx)
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -385,10 +369,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -26,7 +26,6 @@ import (
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -105,11 +104,6 @@ type CAPayload struct {
ClientIdEnum string `json:"clientIdEnum"`
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -135,7 +129,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
allowedDatasets := s.BigQueryAllowedDatasets()
@@ -153,31 +147,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
params := parameters.Parameters{userQueryParameter, tableRefsParameter}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// Get cloud-platform token source for Gemini Data Analytics API during initialization
var bigQueryTokenSourceWithScope oauth2.TokenSource
if !s.UseClientAuthorization() {
ctx := context.Background()
ts, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
}
bigQueryTokenSourceWithScope = ts
}
// finish tool setup
t := Tool{
Config: cfg,
Project: s.BigQueryProject(),
Location: s.BigQueryLocation(),
Parameters: params,
Client: s.BigQueryClient(),
UseClientOAuth: s.UseClientAuthorization(),
TokenSource: bigQueryTokenSourceWithScope,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
MaxQueryResultRows: s.GetMaxQueryResultRows(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -187,18 +162,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project string
Location string
Client *bigqueryapi.Client
TokenSource oauth2.TokenSource
manifest tools.Manifest
mcpManifest tools.McpManifest
MaxQueryResultRows int
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -206,11 +172,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
var tokenStr string
var err error
// Get credentials for the API call
if t.UseClientOAuth {
if source.UseClientAuthorization() {
// Use client-side access token
if accessToken == "" {
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
@@ -220,11 +190,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} else {
// Get cloud-platform token source for Gemini Data Analytics API during initialization
tokenSource, err := source.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
}
// Use cloud-platform token source for Gemini Data Analytics API
if t.TokenSource == nil {
if tokenSource == nil {
return nil, fmt.Errorf("cloud-platform token source is missing")
}
token, err := t.TokenSource.Token()
token, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err)
}
@@ -245,17 +221,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
for _, tableRef := range tableRefs {
if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
}
}
}
// Construct URL, headers, and payload
projectID := t.Project
location := t.Location
projectID := source.BigQueryProject()
location := source.BigQueryLocation()
if location == "" {
location = "us"
}
@@ -279,7 +255,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// Call the streaming API
response, err := getStream(caURL, payload, headers, t.MaxQueryResultRows)
response, err := getStream(caURL, payload, headers, source.GetMaxQueryResultRows())
if err != nil {
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
}
@@ -303,8 +279,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
// StreamMessage represents a single message object from the streaming API response.
@@ -580,6 +560,6 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s
return append(messages, newMessage)
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -60,11 +60,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -90,7 +85,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
var sqlDescriptionBuilder strings.Builder
@@ -136,18 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
WriteMode: s.BigQueryWriteMode(),
SessionProvider: s.BigQuerySession(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -157,18 +144,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
WriteMode string
SessionProvider bigqueryds.BigQuerySessionProvider
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -176,6 +154,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {
@@ -186,17 +169,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
}
bqClient := t.Client
restService := t.RestService
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -204,8 +186,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
var connProps []*bigqueryapi.ConnectionProperty
var session *bigqueryds.Session
if t.WriteMode == bigqueryds.WriteModeProtected {
session, err = t.SessionProvider(ctx)
if source.BigQueryWriteMode() == bigqueryds.WriteModeProtected {
session, err = source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
}
@@ -221,7 +203,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
statementType := dryRunJob.Statistics.Query.StatementType
switch t.WriteMode {
switch source.BigQueryWriteMode() {
case bigqueryds.WriteModeBlocked:
if statementType != "SELECT" {
return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed")
@@ -235,7 +217,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
switch statementType {
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
@@ -270,7 +252,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
} else if statementType != "SELECT" {
// If dry run yields no tables, fall back to the parser for non-SELECT statements
// to catch unsafe operations like EXECUTE IMMEDIATE.
parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project())
parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project())
if parseErr != nil {
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
@@ -282,7 +264,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
parts := strings.Split(tableID, ".")
if len(parts) == 3 {
projectID, datasetID := parts[0], parts[1]
if !t.IsDatasetAllowed(projectID, datasetID) {
if !source.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
}
}
@@ -374,10 +356,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -57,11 +57,6 @@ type compatibleSource interface {
BigQuerySession() bigqueryds.BigQuerySessionProvider
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
allowedDatasets := s.BigQueryAllowedDatasets()
@@ -116,17 +111,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
SessionProvider: s.BigQuerySession(),
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -136,17 +124,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
SessionProvider bigqueryds.BigQuerySessionProvider
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -154,6 +134,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
historyData, ok := paramsMap["history_data"].(string)
if !ok {
@@ -188,17 +173,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
bqClient := t.Client
restService := t.RestService
var err error
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, false)
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -207,9 +191,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
var historyDataSource string
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
var connProps []*bigqueryapi.ConnectionProperty
session, err := t.SessionProvider(ctx)
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -218,7 +202,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
{Key: "session_id", Value: session.ID},
}
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps)
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps)
if err != nil {
return nil, fmt.Errorf("query validation failed: %w", err)
}
@@ -230,7 +214,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
queryStats := dryRunJob.Statistics.Query
if queryStats != nil {
for _, tableRef := range queryStats.ReferencedTables {
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
}
}
@@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
historyDataSource = fmt.Sprintf("(%s)", historyData)
} else {
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
parts := strings.Split(historyData, ".")
var projectID, datasetID string
@@ -249,13 +233,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
projectID = parts[0]
datasetID = parts[1]
case 2: // dataset.table
projectID = t.Client.Project()
projectID = source.BigQueryClient().Project()
datasetID = parts[0]
default:
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
}
if !t.IsDatasetAllowed(projectID, datasetID) {
if !source.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
}
}
@@ -279,7 +263,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// JobStatistics.QueryStatistics.StatementType
query := bqClient.Query(sql)
query.Location = bqClient.Location
session, err := t.SessionProvider(ctx)
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -349,10 +333,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -54,11 +54,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -84,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
defaultProjectID := s.BigQueryProject()
@@ -104,14 +99,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -121,15 +112,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
IsDatasetAllowed func(projectID, datasetID string) bool
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -137,6 +122,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {
@@ -148,22 +138,21 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
}
bqClient := t.Client
var err error
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
if !t.IsDatasetAllowed(projectId, datasetId) {
if !source.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
@@ -193,10 +182,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -55,11 +55,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
defaultProjectID := s.BigQueryProject()
@@ -108,14 +103,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -125,15 +116,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
IsDatasetAllowed func(projectID, datasetID string) bool
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -141,6 +126,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {
@@ -157,20 +147,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
}
if !t.IsDatasetAllowed(projectId, datasetId) {
if !source.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
bqClient := t.Client
bqClient := source.BigQueryClient()
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -203,10 +192,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -52,11 +52,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -82,7 +77,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
var projectParameter parameters.Parameter
@@ -103,14 +98,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -120,15 +111,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
AllowedDatasets []string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -136,8 +121,13 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
if len(t.AllowedDatasets) > 0 {
return t.AllowedDatasets, nil
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
if len(source.BigQueryAllowedDatasets()) > 0 {
return source.BigQueryAllowedDatasets(), nil
}
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
@@ -145,14 +135,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
}
bqClient := t.Client
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -197,10 +187,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -55,11 +55,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
defaultProjectID := s.BigQueryProject()
@@ -107,14 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -124,15 +115,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -140,6 +125,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {
@@ -151,18 +141,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
}
if !t.IsDatasetAllowed(projectId, datasetId) {
if !source.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
bqClient := t.Client
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -208,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -51,11 +51,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -72,20 +67,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// Initialize the search configuration with the provided sources
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
// Get the Dataplex client using the method from the source
makeCatalogClient := s.MakeDataplexCatalogClient()
prompt := parameters.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.")
datasetIds := parameters.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", parameters.NewStringParameter("datasetId", "The IDs of the bigquery dataset."))
projectIds := parameters.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", parameters.NewStringParameter("projectId", "The IDs of the bigquery project."))
@@ -100,11 +81,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, params, nil)
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
MakeCatalogClient: makeCatalogClient,
ProjectID: s.BigQueryProject(),
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
@@ -117,12 +95,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
type Tool struct {
Config
Parameters parameters.Parameters
UseClientOAuth bool
MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error)
ProjectID string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -133,8 +108,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func constructSearchQueryHelper(predicate string, operator string, items []string) string {
@@ -207,6 +186,11 @@ func ExtractType(resourceString string) string {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
pageSize := int32(paramsMap["pageSize"].(int))
prompt, _ := paramsMap["prompt"].(string)
@@ -228,14 +212,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
req := &dataplexpb.SearchEntriesRequest{
Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)),
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
Name: fmt.Sprintf("projects/%s/locations/global", source.BigQueryProject()),
PageSize: pageSize,
SemanticSearch: true,
}
catalogClient, dataplexClientCreator, _ := t.MakeCatalogClient()
catalogClient, dataplexClientCreator, _ := source.MakeDataplexCatalogClient()()
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
@@ -248,7 +232,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
it := catalogClient.SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.BigQueryProject())
}
var results []Response
@@ -288,6 +272,6 @@ func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -57,11 +57,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -81,18 +76,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
if err != nil {
return nil, err
@@ -102,15 +85,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
AllParams: allParameters,
UseClientOAuth: s.UseClientAuthorization(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
SessionProvider: s.BigQuerySession(),
ClientCreator: s.BigQueryClientCreator(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -120,15 +98,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
AllParams parameters.Parameters `yaml:"allParams"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
SessionProvider bigqueryds.BigQuerySessionProvider
ClientCreator bigqueryds.BigqueryClientCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -136,6 +108,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
@@ -212,16 +189,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
lowLevelParams = append(lowLevelParams, lowLevelParam)
}
bqClient := t.Client
restService := t.RestService
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -232,8 +209,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
query.Location = bqClient.Location
connProps := []*bigqueryapi.ConnectionProperty{}
if t.SessionProvider != nil {
session, err := t.SessionProvider(ctx)
if source.BigQuerySession() != nil {
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -311,10 +288,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
"cloud.google.com/go/bigtable"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigtabledb "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -46,11 +45,6 @@ type compatibleSource interface {
BigtableClient() *bigtable.Client
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigtabledb.Source{}
var compatibleSources = [...]string{bigtabledb.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
if err != nil {
return nil, err
@@ -93,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
AllParams: allParameters,
Client: s.BigtableClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -105,9 +86,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Client *bigtable.Client
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -156,6 +135,11 @@ func getMapParamsType(tparams parameters.Parameters, params parameters.ParamValu
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
@@ -172,7 +156,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("fail to get map params: %w", err)
}
ps, err := t.Client.PrepareStatement(
ps, err := source.BigtableClient().PrepareStatement(
ctx,
newStatement,
mapParamsType,
@@ -224,10 +208,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
gocql "github.com/apache/cassandra-gocql-driver/v2"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -46,10 +45,6 @@ type compatibleSource interface {
CassandraSession() *gocql.Session
}
var _ compatibleSource = &cassandra.Source{}
var compatibleSources = [...]string{cassandra.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -61,20 +56,15 @@ type Config struct {
TemplateParameters parameters.Parameters `yaml:"templateParameters"`
}
var _ tools.ToolConfig = Config{}
// ToolConfigKind implements tools.ToolConfig.
func (c Config) ToolConfigKind() string {
return kind
}
// Initialize implements tools.ToolConfig.
func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[c.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", c.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, err := parameters.ProcessParameters(c.TemplateParameters, c.Parameters)
if err != nil {
return nil, err
@@ -85,25 +75,17 @@ func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
t := Tool{
Config: c,
AllParams: allParameters,
Session: s.CassandraSession(),
manifest: tools.Manifest{Description: c.Description, Parameters: paramManifest, AuthRequired: c.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
// ToolConfigKind implements tools.ToolConfig.
func (c Config) ToolConfigKind() string {
return kind
}
var _ tools.ToolConfig = Config{}
var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Session *gocql.Session
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -113,8 +95,8 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
// RequiresClientAuthorization implements tools.Tool.
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
// Authorized implements tools.Tool.
@@ -124,6 +106,11 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
// Invoke implements tools.Tool.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
@@ -135,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to extract standard params %w", err)
}
sliceParams := newParams.AsSlice()
iter := t.Session.Query(newStatement, sliceParams...).IterContext(ctx)
iter := source.CassandraSession().Query(newStatement, sliceParams...).IterContext(ctx)
// Create a slice to store the out
var out []map[string]interface{}
@@ -170,8 +157,6 @@ func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any)
return parameters.ParseParams(t.AllParams, data, claims)
}
var _ tools.Tool = Tool{}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -25,12 +25,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
type compatibleSource interface {
ClickHousePool() *sql.DB
}
var compatibleSources = []string{"clickhouse"}
const executeSQLKind string = "clickhouse-execute-sql"
func init() {
@@ -47,6 +41,10 @@ func newExecuteSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder
return actual, nil
}
type compatibleSource interface {
ClickHousePool() *sql.DB
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -62,16 +60,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", executeSQLKind, compatibleSources)
}
sqlParameter := parameters.NewStringParameter("sql", "The SQL statement to execute.")
params := parameters.Parameters{sqlParameter}
@@ -80,7 +68,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
Parameters: params,
Pool: s.ClickHousePool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -91,9 +78,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Pool *sql.DB
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -103,13 +88,18 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {
return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"])
}
results, err := t.Pool.QueryContext(ctx, sql)
results, err := source.ClickHousePool().QueryContext(ctx, sql)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -183,10 +173,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -25,12 +25,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
type compatibleSource interface {
ClickHousePool() *sql.DB
}
var compatibleSources = []string{"clickhouse"}
const listDatabasesKind string = "clickhouse-list-databases"
func init() {
@@ -47,6 +41,10 @@ func newListDatabasesConfig(ctx context.Context, name string, decoder *yaml.Deco
return actual, nil
}
type compatibleSource interface {
ClickHousePool() *sql.DB
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -63,23 +61,12 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listDatabasesKind, compatibleSources)
}
allParameters, paramManifest, _ := parameters.ProcessParameters(nil, cfg.Parameters)
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
t := Tool{
Config: cfg,
AllParams: allParameters,
Pool: s.ClickHousePool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -90,9 +77,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Pool *sql.DB
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -102,10 +87,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
// Query to list all databases
query := "SHOW DATABASES"
results, err := t.Pool.QueryContext(ctx, query)
results, err := source.ClickHousePool().QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -146,10 +136,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -32,21 +31,6 @@ func TestListDatabasesConfigToolConfigKind(t *testing.T) {
}
}
func TestListDatabasesConfigInitializeMissingSource(t *testing.T) {
cfg := Config{
Name: "test-list-databases",
Kind: listDatabasesKind,
Source: "missing-source",
Description: "Test list databases tool",
}
srcs := map[string]sources.Source{}
_, err := cfg.Initialize(srcs)
if err == nil {
t.Error("expected error for missing source")
}
}
func TestParseFromYamlClickHouseListDatabases(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {

View File

@@ -25,12 +25,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
type compatibleSource interface {
ClickHousePool() *sql.DB
}
var compatibleSources = []string{"clickhouse"}
const listTablesKind string = "clickhouse-list-tables"
const databaseKey string = "database"
@@ -48,6 +42,10 @@ func newListTablesConfig(ctx context.Context, name string, decoder *yaml.Decoder
return actual, nil
}
type compatibleSource interface {
ClickHousePool() *sql.DB
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -64,16 +62,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listTablesKind, compatibleSources)
}
databaseParameter := parameters.NewStringParameter(databaseKey, "The database to list tables from.")
params := parameters.Parameters{databaseParameter}
@@ -83,7 +71,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
AllParams: allParameters,
Pool: s.ClickHousePool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -94,9 +81,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Pool *sql.DB
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -106,6 +91,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
database, ok := mapParams[databaseKey].(string)
if !ok {
@@ -115,7 +105,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// Query to list all tables in the specified database
query := fmt.Sprintf("SHOW TABLES FROM %s", database)
results, err := t.Pool.QueryContext(ctx, query)
results, err := source.ClickHousePool().QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -157,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -32,21 +31,6 @@ func TestListTablesConfigToolConfigKind(t *testing.T) {
}
}
func TestListTablesConfigInitializeMissingSource(t *testing.T) {
cfg := Config{
Name: "test-list-tables",
Kind: listTablesKind,
Source: "missing-source",
Description: "Test list tables tool",
}
srcs := map[string]sources.Source{}
_, err := cfg.Initialize(srcs)
if err == nil {
t.Error("expected error for missing source")
}
}
func TestParseFromYamlClickHouseListTables(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {

View File

@@ -25,21 +25,15 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
type compatibleSource interface {
ClickHousePool() *sql.DB
}
var compatibleSources = []string{"clickhouse"}
const sqlKind string = "clickhouse-sql"
func init() {
if !tools.Register(sqlKind, newSQLConfig) {
if !tools.Register(sqlKind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", sqlKind))
}
}
func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
@@ -47,6 +41,10 @@ func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tool
return actual, nil
}
type compatibleSource interface {
ClickHousePool() *sql.DB
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -65,23 +63,12 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", sqlKind, compatibleSources)
}
allParameters, paramManifest, _ := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
t := Tool{
Config: cfg,
AllParams: allParameters,
Pool: s.ClickHousePool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -93,7 +80,6 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Pool *sql.DB
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -103,6 +89,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
@@ -115,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
sliceParams := newParams.AsSlice()
results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...)
results, err := source.ClickHousePool().QueryContext(ctx, newStatement, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -191,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -142,66 +142,6 @@ func TestSQLConfigInitializeValidSource(t *testing.T) {
}
}
func TestSQLConfigInitializeMissingSource(t *testing.T) {
config := Config{
Name: "test-tool",
Kind: sqlKind,
Source: "missing-source",
Description: "Test tool",
Statement: "SELECT 1",
Parameters: parameters.Parameters{},
}
sources := map[string]sources.Source{}
_, err := config.Initialize(sources)
if err == nil {
t.Fatal("Expected error for missing source, got nil")
}
expectedErr := `no source named "missing-source" configured`
if err.Error() != expectedErr {
t.Errorf("Expected error %q, got %q", expectedErr, err.Error())
}
}
// mockIncompatibleSource is a mock source that doesn't implement the compatibleSource interface
type mockIncompatibleSource struct{}
func (m *mockIncompatibleSource) SourceKind() string {
return "mock"
}
func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig {
return nil
}
func TestSQLConfigInitializeIncompatibleSource(t *testing.T) {
config := Config{
Name: "test-tool",
Kind: sqlKind,
Source: "incompatible-source",
Description: "Test tool",
Statement: "SELECT 1",
Parameters: parameters.Parameters{},
}
mockSource := &mockIncompatibleSource{}
sources := map[string]sources.Source{
"incompatible-source": mockSource,
}
_, err := config.Initialize(sources)
if err == nil {
t.Fatal("Expected error for incompatible source, got nil")
}
if err.Error() == "" {
t.Error("Expected non-empty error message")
}
}
func TestToolManifest(t *testing.T) {
tool := Tool{
manifest: tools.Manifest{

View File

@@ -0,0 +1,206 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
const kind string = "cloud-gemini-data-analytics-query"
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type compatibleSource interface {
GetProjectID() string
GetBaseURL() string
UseClientAuthorization() bool
GetClient(context.Context, string) (*http.Client, error)
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
Location string `yaml:"location" validate:"required"`
Context *QueryDataContext `yaml:"context" validate:"required"`
GenerationOptions *GenerationOptions `yaml:"generationOptions,omitempty"`
AuthRequired []string `yaml:"authRequired"`
}
// validate interface
var _ tools.ToolConfig = Config{}
func (cfg Config) ToolConfigKind() string {
return kind
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// Define the parameters for the Gemini Data Analytics Query API
// The prompt is the only input parameter.
allParameters := parameters.Parameters{
parameters.NewStringParameterWithRequired("prompt", "The natural language question to ask.", true),
}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
return Tool{
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
}
// validate interface
var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
return t.Config
}
// Invoke executes the tool logic
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
prompt, ok := paramsMap["prompt"].(string)
if !ok {
return nil, fmt.Errorf("prompt parameter not found or not a string")
}
// The API endpoint itself always uses the "global" location.
apiLocation := "global"
apiParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", source.GetBaseURL(), apiParent)
// The parent in the request payload uses the tool's configured location.
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
payload := &QueryDataRequest{
Parent: payloadParent,
Prompt: prompt,
Context: t.Context,
GenerationOptions: t.GenerationOptions,
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
}
// Parse the access token if provided
var tokenStr string
if source.UseClientAuthorization() {
var err error
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
client, err := source.GetClient(ctx, tokenStr)
if err != nil {
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
return result, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.AllParams, data, claims)
}
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -0,0 +1,353 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda_test
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
func TestParseFromYaml(t *testing.T) {
t.Parallel()
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
my-gda-query-tool:
kind: cloud-gemini-data-analytics-query
source: gda-api-source
description: Test Description
location: us-central1
context:
datasourceReferences:
spannerReference:
databaseReference:
projectId: "cloud-db-nl2sql"
region: "us-central1"
instanceId: "evalbench"
databaseId: "financial"
engine: "GOOGLE_SQL"
agentContextReference:
contextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates"
generationOptions:
generateQueryResult: true
`,
want: map[string]tools.ToolConfig{
"my-gda-query-tool": cloudgdatool.Config{
Name: "my-gda-query-tool",
Kind: "cloud-gemini-data-analytics-query",
Source: "gda-api-source",
Description: "Test Description",
Location: "us-central1",
AuthRequired: []string{},
Context: &cloudgdatool.QueryDataContext{
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerateQueryResult: true,
},
},
},
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Tools) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools)
}
})
}
}
// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header.
type authRoundTripper struct {
Token string
Next http.RoundTripper
}
func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
newReq.Header.Set("Authorization", rt.Token)
if rt.Next == nil {
return http.DefaultTransport.RoundTrip(&newReq)
}
return rt.Next.RoundTrip(&newReq)
}
type mockSource struct {
kind string
client *http.Client // Can be used to inject a specific client
baseURL string // BaseURL is needed to implement sources.Source.BaseURL
config cloudgdasrc.Config // to return from ToConfig
}
func (m *mockSource) SourceKind() string { return m.kind }
func (m *mockSource) ToConfig() sources.SourceConfig { return m.config }
func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) {
if m.client != nil {
return m.client, nil
}
// Default client for testing if not explicitly set
transport := &http.Transport{}
authTransport := &authRoundTripper{
Token: "Bearer test-access-token", // Dummy token
Next: transport,
}
return &http.Client{Transport: authTransport}, nil
}
func (m *mockSource) UseClientAuthorization() bool { return false }
func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
return m, nil
}
func (m *mockSource) BaseURL() string { return m.baseURL }
func TestInitialize(t *testing.T) {
t.Parallel()
srcs := map[string]sources.Source{
"gda-api-source": &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: &http.Client{},
BaseURL: cloudgdasrc.Endpoint,
},
}
tcs := []struct {
desc string
cfg cloudgdatool.Config
}{
{
desc: "successful initialization",
cfg: cloudgdatool.Config{
Name: "my-gda-query-tool",
Kind: "cloud-gemini-data-analytics-query",
Source: "gda-api-source",
Description: "Test Description",
Location: "us-central1",
},
},
}
// Add an incompatible source for testing
srcs["incompatible-source"] = &mockSource{kind: "another-kind"}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
tool, err := tc.cfg.Initialize(srcs)
if err != nil {
t.Fatalf("did not expect an error but got: %v", err)
}
// Basic sanity check on the returned tool
_ = tool // Avoid unused variable error
})
}
}
func TestInvoke(t *testing.T) {
t.Parallel()
// Mock the HTTP client and server for Invoke testing
serverMux := http.NewServeMux()
// Update expected URL path to include the location "us-central1"
serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST method, got %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
// Read and unmarshal the request body
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("failed to read request body: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
var reqPayload cloudgdatool.QueryDataRequest
if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil {
t.Errorf("failed to unmarshal request payload: %v", err)
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
// Verify expected fields
if r.Header.Get("Authorization") == "" {
t.Errorf("expected Authorization header, got empty")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" {
t.Errorf("unexpected prompt: %s", reqPayload.Prompt)
}
// Verify payload's parent uses the tool's configured location
if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") {
t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1"))
}
// Verify context from config
if reqPayload.Context == nil ||
reqPayload.Context.DatasourceReferences == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" {
t.Errorf("unexpected context: %v", reqPayload.Context)
}
// Verify generation options from config
if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult {
t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions)
}
// Simulate a successful response
resp := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
}
_ = json.NewEncoder(w).Encode(resp)
})
mockServer := httptest.NewServer(serverMux)
defer mockServer.Close()
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
// Create an authenticated client that uses the mock server
authTransport := &authRoundTripper{
Token: "Bearer test-access-token",
Next: mockServer.Client().Transport,
}
authClient := &http.Client{Transport: authTransport}
// Create a real cloudgdasrc.Source but inject the authenticated client
mockGdaSource := &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: authClient,
BaseURL: mockServer.URL,
}
srcs := map[string]sources.Source{
"mock-gda-source": mockGdaSource,
}
// Initialize the tool config with context
toolCfg := cloudgdatool.Config{
Name: "query-data-tool",
Kind: "cloud-gemini-data-analytics-query",
Source: "mock-gda-source",
Description: "Query Gemini Data Analytics",
Location: "us-central1", // Set location for the test
Context: &cloudgdatool.QueryDataContext{
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerateQueryResult: true,
},
}
tool, err := toolCfg.Initialize(srcs)
if err != nil {
t.Fatalf("failed to initialize tool: %v", err)
}
// Prepare parameters for invocation - ONLY prompt
params := parameters.ParamValues{
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
}
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil)
// Invoke the tool
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
if err != nil {
t.Fatalf("tool invocation failed: %v", err)
}
// Validate the result
expectedResult := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
}
if !cmp.Equal(expectedResult, result) {
t.Errorf("unexpected result: got %v, want %v", result, expectedResult)
}
}

View File

@@ -0,0 +1,116 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda
// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto
// QueryDataRequest represents the JSON body for the queryData API
type QueryDataRequest struct {
Parent string `json:"parent"`
Prompt string `json:"prompt"`
Context *QueryDataContext `json:"context,omitempty"`
GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"`
}
// QueryDataContext reflects the proto definition for the query context.
type QueryDataContext struct {
DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"`
}
// DatasourceReferences reflects the proto definition for datasource references, using a oneof.
type DatasourceReferences struct {
SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"`
AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"`
CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"`
}
// SpannerReference reflects the proto definition for Spanner database reference.
type SpannerReference struct {
DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// SpannerDatabaseReference reflects the proto definition for a Spanner database reference.
type SpannerDatabaseReference struct {
Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// SpannerEngine represents the engine of the Spanner instance.
type SpannerEngine string
const (
SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED"
SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL"
SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL"
)
// AlloyDBReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBReference struct {
DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBDatabaseReference struct {
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLReference struct {
DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLDatabaseReference struct {
Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLEngine represents the engine of the Cloud SQL instance.
type CloudSQLEngine string
const (
CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED"
CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL"
CloudSQLEngineMySQL CloudSQLEngine = "MYSQL"
)
// AgentContextReference reflects the proto definition for agent context.
type AgentContextReference struct {
ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"`
}
// GenerationOptions reflects the proto definition for generation options.
type GenerationOptions struct {
GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"`
GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"`
GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"`
GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"`
}

View File

@@ -62,11 +62,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,35 +78,16 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
urlParameter := parameters.NewStringParameter(pageURLKey, "The full URL of the FHIR page to fetch. This would be the value of `Bundle.entry.link.url` field within the response returned from FHIR search or FHIR patient everything operations.")
params := parameters.Parameters{urlParameter}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -121,14 +97,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -136,13 +107,18 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
url, ok := params.AsMap()[pageURLKey].(string)
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
}
var httpClient *http.Client
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
@@ -150,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
httpClient = oauth2.NewClient(ctx, ts)
} else {
// The t.Service object holds a client with the default credentials.
// The source.Service() object holds a client with the default credentials.
// However, the client is not exported, so we have to create a new one.
var err error
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
@@ -201,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -62,11 +62,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -92,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
idParameter := parameters.NewStringParameter(patientIDKey, "The ID of the patient FHIR resource for which the information is required")
@@ -106,17 +101,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -126,15 +114,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -142,7 +124,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
@@ -151,20 +138,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
}
svc := t.Service
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", t.Project, t.Region, t.Dataset, storeID, patientID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID)
var opts []googleapi.CallOption
if val, ok := params.AsMap()[typeFilterKey]; ok {
types, ok := val.([]any)
@@ -225,10 +212,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -78,11 +78,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -108,7 +103,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{
@@ -140,17 +135,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -160,15 +148,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -176,19 +158,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
@@ -261,7 +248,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
opts = append(opts, googleapi.QueryParameter("_summary", "text"))
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search patient resources: %w", err)
@@ -298,10 +285,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -51,11 +51,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -72,33 +67,15 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
params := parameters.Parameters{}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -108,13 +85,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -122,22 +95,26 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
svc := t.Service
var err error
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
@@ -161,10 +138,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{}
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{}
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

Some files were not shown because too many files have changed in this diff Show More