mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-18 11:02:26 -05:00
Compare commits
17 Commits
lsc-177138
...
anti-promp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92ad186cd7 | ||
|
|
967a72da11 | ||
|
|
7daa4111f4 | ||
|
|
18885f6433 | ||
|
|
21d676ed58 | ||
|
|
1c353a3c8e | ||
|
|
a02ca45ba3 | ||
|
|
8217d1424d | ||
|
|
f520b4ed8a | ||
|
|
80315a0ebd | ||
|
|
5788605818 | ||
|
|
0641da0353 | ||
|
|
c9b775d38e | ||
|
|
b25fd83bb9 | ||
|
|
ec844d5087 | ||
|
|
5978a746fd | ||
|
|
5b79f712e4 |
@@ -338,7 +338,7 @@ steps:
|
||||
.ci/test_with_coverage.sh \
|
||||
"Spanner" \
|
||||
spanner \
|
||||
spanner
|
||||
spanner || echo "Integration tests failed." # ignore test failures
|
||||
|
||||
- id: "neo4j"
|
||||
name: golang:1
|
||||
|
||||
18
.github/renovate.json5
vendored
18
.github/renovate.json5
vendored
@@ -24,5 +24,23 @@
|
||||
],
|
||||
pinDigests: true,
|
||||
},
|
||||
{
|
||||
groupName: 'Go',
|
||||
matchManagers: [
|
||||
'gomod',
|
||||
],
|
||||
},
|
||||
{
|
||||
groupName: 'Node',
|
||||
matchManagers: [
|
||||
'npm',
|
||||
],
|
||||
},
|
||||
{
|
||||
groupName: 'Pip',
|
||||
matchManagers: [
|
||||
'pip_requirements',
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
@@ -51,6 +51,10 @@ ignoreFiles = ["quickstart/shared", "quickstart/python", "quickstart/js", "quick
|
||||
# Add a new version block here before every release
|
||||
# The order of versions in this file is mirrored into the dropdown
|
||||
|
||||
[[params.versions]]
|
||||
version = "v0.24.0"
|
||||
url = "https://googleapis.github.io/genai-toolbox/v0.24.0/"
|
||||
|
||||
[[params.versions]]
|
||||
version = "v0.23.0"
|
||||
url = "https://googleapis.github.io/genai-toolbox/v0.23.0/"
|
||||
|
||||
17
CHANGELOG.md
17
CHANGELOG.md
@@ -1,5 +1,22 @@
|
||||
# Changelog
|
||||
|
||||
## [0.24.0](https://github.com/googleapis/genai-toolbox/compare/v0.23.0...v0.24.0) (2025-12-19)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* **sources/cloud-gemini-data-analytics:** Add the Gemini Data Analytics (GDA) integration for DB NL2SQL conversion to Toolbox ([#2181](https://github.com/googleapis/genai-toolbox/issues/2181)) ([aa270b2](https://github.com/googleapis/genai-toolbox/commit/aa270b2630da2e3d618db804ca95550445367dbc))
|
||||
* **source/cloudsqlmysql:** Add support for IAM authentication in Cloud SQL MySQL source ([#2050](https://github.com/googleapis/genai-toolbox/issues/2050)) ([af3d3c5](https://github.com/googleapis/genai-toolbox/commit/af3d3c52044bea17781b89ce4ab71ff0f874ac20))
|
||||
* **sources/oracle:** Add Oracle OCI and Wallet support ([#1945](https://github.com/googleapis/genai-toolbox/issues/1945)) ([8ea39ec](https://github.com/googleapis/genai-toolbox/commit/8ea39ec32fbbaa97939c626fec8c5d86040ed464))
|
||||
* Support combining prebuilt and custom tool configurations ([#2188](https://github.com/googleapis/genai-toolbox/issues/2188)) ([5788605](https://github.com/googleapis/genai-toolbox/commit/57886058188aa5d2a51d5846a98bc6d8a650edd1))
|
||||
* **tools/mysql-get-query-plan:** Add new `mysql-get-query-plan` tool for MySQL source ([#2123](https://github.com/googleapis/genai-toolbox/issues/2123)) ([0641da0](https://github.com/googleapis/genai-toolbox/commit/0641da0353857317113b2169e547ca69603ddfde))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* **spanner:** Move list graphs validation to runtime ([#2154](https://github.com/googleapis/genai-toolbox/issues/2154)) ([914b3ee](https://github.com/googleapis/genai-toolbox/commit/914b3eefda40a650efe552d245369e007277dab5))
|
||||
|
||||
|
||||
## [0.23.0](https://github.com/googleapis/genai-toolbox/compare/v0.22.0...v0.23.0) (2025-12-11)
|
||||
|
||||
|
||||
|
||||
14
README.md
14
README.md
@@ -140,7 +140,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```sh
|
||||
> # see releases page for other versions
|
||||
> export VERSION=0.23.0
|
||||
> export VERSION=0.24.0
|
||||
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
|
||||
> chmod +x toolbox
|
||||
> ```
|
||||
@@ -153,7 +153,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```sh
|
||||
> # see releases page for other versions
|
||||
> export VERSION=0.23.0
|
||||
> export VERSION=0.24.0
|
||||
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
|
||||
> chmod +x toolbox
|
||||
> ```
|
||||
@@ -166,7 +166,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```sh
|
||||
> # see releases page for other versions
|
||||
> export VERSION=0.23.0
|
||||
> export VERSION=0.24.0
|
||||
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
|
||||
> chmod +x toolbox
|
||||
> ```
|
||||
@@ -179,7 +179,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```cmd
|
||||
> :: see releases page for other versions
|
||||
> set VERSION=0.23.0
|
||||
> set VERSION=0.24.0
|
||||
> curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
|
||||
> ```
|
||||
>
|
||||
@@ -191,7 +191,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```powershell
|
||||
> # see releases page for other versions
|
||||
> $VERSION = "0.23.0"
|
||||
> $VERSION = "0.24.0"
|
||||
> curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
|
||||
> ```
|
||||
>
|
||||
@@ -204,7 +204,7 @@ You can also install Toolbox as a container:
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.23.0
|
||||
export VERSION=0.24.0
|
||||
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
|
||||
```
|
||||
|
||||
@@ -228,7 +228,7 @@ To install from source, ensure you have the latest version of
|
||||
[Go installed](https://go.dev/doc/install), and then run the following command:
|
||||
|
||||
```sh
|
||||
go install github.com/googleapis/genai-toolbox@v0.23.0
|
||||
go install github.com/googleapis/genai-toolbox@v0.24.0
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
166
cmd/root.go
166
cmd/root.go
@@ -169,6 +169,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqllisttables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllistactivequeries"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablefragmentation"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables"
|
||||
@@ -234,6 +235,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
@@ -354,12 +356,12 @@ func NewCommand(opts ...Option) *Command {
|
||||
flags.StringVarP(&cmd.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.")
|
||||
flags.IntVarP(&cmd.cfg.Port, "port", "p", 5000, "Port the server will listen on.")
|
||||
|
||||
flags.StringVar(&cmd.tools_file, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --prebuilt.")
|
||||
flags.StringVar(&cmd.tools_file, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.")
|
||||
// deprecate tools_file
|
||||
_ = flags.MarkDeprecated("tools_file", "please use --tools-file instead")
|
||||
flags.StringVar(&cmd.tools_file, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --prebuilt, --tools-files, or --tools-folder.")
|
||||
flags.StringSliceVar(&cmd.tools_files, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --prebuilt, --tools-file, or --tools-folder.")
|
||||
flags.StringVar(&cmd.tools_folder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --prebuilt, --tools-file, or --tools-files.")
|
||||
flags.StringVar(&cmd.tools_file, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.")
|
||||
flags.StringSliceVar(&cmd.tools_files, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file, or --tools-folder.")
|
||||
flags.StringVar(&cmd.tools_folder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file, or --tools-files.")
|
||||
flags.Var(&cmd.cfg.LogLevel, "log-level", "Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'.")
|
||||
flags.Var(&cmd.cfg.LoggingFormat, "logging-format", "Specify logging format to use. Allowed: 'standard' or 'JSON'.")
|
||||
flags.BoolVar(&cmd.cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.")
|
||||
@@ -367,7 +369,7 @@ func NewCommand(opts ...Option) *Command {
|
||||
flags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
|
||||
// Fetch prebuilt tools sources to customize the help description
|
||||
prebuiltHelp := fmt.Sprintf(
|
||||
"Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. Allowed: '%s'.",
|
||||
"Use a prebuilt tool configuration by source type. Allowed: '%s'.",
|
||||
strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"),
|
||||
)
|
||||
flags.StringVar(&cmd.prebuiltConfig, "prebuilt", "", prebuiltHelp)
|
||||
@@ -461,6 +463,9 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
|
||||
if _, exists := merged.AuthSources[name]; exists {
|
||||
conflicts = append(conflicts, fmt.Sprintf("authSource '%s' (file #%d)", name, fileIndex+1))
|
||||
} else {
|
||||
if merged.AuthSources == nil {
|
||||
merged.AuthSources = make(server.AuthServiceConfigs)
|
||||
}
|
||||
merged.AuthSources[name] = authSource
|
||||
}
|
||||
}
|
||||
@@ -837,16 +842,10 @@ func run(cmd *Command) error {
|
||||
}
|
||||
}()
|
||||
|
||||
var toolsFile ToolsFile
|
||||
var allToolsFiles []ToolsFile
|
||||
|
||||
// Load Prebuilt Configuration
|
||||
if cmd.prebuiltConfig != "" {
|
||||
// Make sure --prebuilt and --tools-file/--tools-files/--tools-folder flags are mutually exclusive
|
||||
if cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != "" {
|
||||
errMsg := fmt.Errorf("--prebuilt and --tools-file/--tools-files/--tools-folder flags cannot be used simultaneously")
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
// Use prebuilt tools
|
||||
buf, err := prebuiltconfigs.Get(cmd.prebuiltConfig)
|
||||
if err != nil {
|
||||
cmd.logger.ErrorContext(ctx, err.Error())
|
||||
@@ -857,72 +856,96 @@ func run(cmd *Command) error {
|
||||
// Append prebuilt.source to Version string for the User Agent
|
||||
cmd.cfg.Version += "+prebuilt." + cmd.prebuiltConfig
|
||||
|
||||
toolsFile, err = parseToolsFile(ctx, buf)
|
||||
parsed, err := parseToolsFile(ctx, buf)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to parse prebuilt tool configuration: %w", err)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
} else if len(cmd.tools_files) > 0 {
|
||||
// Make sure --tools-file, --tools-files, and --tools-folder flags are mutually exclusive
|
||||
if cmd.tools_file != "" || cmd.tools_folder != "" {
|
||||
errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously")
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
// Use multiple tools files
|
||||
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files)))
|
||||
var err error
|
||||
toolsFile, err = loadAndMergeToolsFiles(ctx, cmd.tools_files)
|
||||
if err != nil {
|
||||
cmd.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
} else if cmd.tools_folder != "" {
|
||||
// Make sure --tools-folder and other flags are mutually exclusive
|
||||
if cmd.tools_file != "" || len(cmd.tools_files) > 0 {
|
||||
errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously")
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
// Use tools folder
|
||||
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder))
|
||||
var err error
|
||||
toolsFile, err = loadAndMergeToolsFolder(ctx, cmd.tools_folder)
|
||||
if err != nil {
|
||||
cmd.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Set default value of tools-file flag to tools.yaml
|
||||
if cmd.tools_file == "" {
|
||||
cmd.tools_file = "tools.yaml"
|
||||
}
|
||||
|
||||
// Read single tool file contents
|
||||
buf, err := os.ReadFile(cmd.tools_file)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, err)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
toolsFile, err = parseToolsFile(ctx, buf)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
allToolsFiles = append(allToolsFiles, parsed)
|
||||
}
|
||||
|
||||
cmd.cfg.SourceConfigs, cmd.cfg.AuthServiceConfigs, cmd.cfg.ToolConfigs, cmd.cfg.ToolsetConfigs, cmd.cfg.PromptConfigs = toolsFile.Sources, toolsFile.AuthServices, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts
|
||||
// Determine if Custom Files should be loaded
|
||||
// Check for explicit custom flags
|
||||
isCustomConfigured := cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != ""
|
||||
|
||||
authSourceConfigs := toolsFile.AuthSources
|
||||
// Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags)
|
||||
useDefaultToolsFile := cmd.prebuiltConfig == "" && !isCustomConfigured
|
||||
|
||||
if useDefaultToolsFile {
|
||||
cmd.tools_file = "tools.yaml"
|
||||
isCustomConfigured = true
|
||||
}
|
||||
|
||||
// Load Custom Configurations
|
||||
if isCustomConfigured {
|
||||
// Enforce exclusivity among custom flags (tools-file vs tools-files vs tools-folder)
|
||||
if (cmd.tools_file != "" && len(cmd.tools_files) > 0) ||
|
||||
(cmd.tools_file != "" && cmd.tools_folder != "") ||
|
||||
(len(cmd.tools_files) > 0 && cmd.tools_folder != "") {
|
||||
errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously")
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
var customTools ToolsFile
|
||||
var err error
|
||||
|
||||
if len(cmd.tools_files) > 0 {
|
||||
// Use tools-files
|
||||
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files)))
|
||||
customTools, err = loadAndMergeToolsFiles(ctx, cmd.tools_files)
|
||||
} else if cmd.tools_folder != "" {
|
||||
// Use tools-folder
|
||||
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder))
|
||||
customTools, err = loadAndMergeToolsFolder(ctx, cmd.tools_folder)
|
||||
} else {
|
||||
// Use single file (tools-file or default `tools.yaml`)
|
||||
buf, readFileErr := os.ReadFile(cmd.tools_file)
|
||||
if readFileErr != nil {
|
||||
errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, readFileErr)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
customTools, err = parseToolsFile(ctx, buf)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
cmd.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
allToolsFiles = append(allToolsFiles, customTools)
|
||||
}
|
||||
|
||||
// Merge Everything
|
||||
// This will error if custom tools collide with prebuilt tools
|
||||
finalToolsFile, err := mergeToolsFiles(allToolsFiles...)
|
||||
if err != nil {
|
||||
cmd.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.cfg.SourceConfigs = finalToolsFile.Sources
|
||||
cmd.cfg.AuthServiceConfigs = finalToolsFile.AuthServices
|
||||
cmd.cfg.ToolConfigs = finalToolsFile.Tools
|
||||
cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets
|
||||
cmd.cfg.PromptConfigs = finalToolsFile.Prompts
|
||||
|
||||
authSourceConfigs := finalToolsFile.AuthSources
|
||||
if authSourceConfigs != nil {
|
||||
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
|
||||
cmd.cfg.AuthServiceConfigs = authSourceConfigs
|
||||
|
||||
for k, v := range authSourceConfigs {
|
||||
if _, exists := cmd.cfg.AuthServiceConfigs[k]; exists {
|
||||
errMsg := fmt.Errorf("resource conflict detected: authSource '%s' has the same name as an existing authService. Please rename your authSource", k)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
cmd.cfg.AuthServiceConfigs[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
|
||||
@@ -973,9 +996,8 @@ func run(cmd *Command) error {
|
||||
}()
|
||||
}
|
||||
|
||||
watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder)
|
||||
|
||||
if !cmd.cfg.DisableReload {
|
||||
if isCustomConfigured && !cmd.cfg.DisableReload {
|
||||
watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder)
|
||||
// start watching the file(s) or folder for changes to trigger dynamic reloading
|
||||
go watchChanges(ctx, watchDirs, watchedFiles, s)
|
||||
}
|
||||
|
||||
245
cmd/root_test.go
245
cmd/root_test.go
@@ -92,6 +92,21 @@ func invokeCommand(args []string) (*Command, string, error) {
|
||||
return c, buf.String(), err
|
||||
}
|
||||
|
||||
// invokeCommandWithContext executes the command with a context and returns the captured output.
|
||||
func invokeCommandWithContext(ctx context.Context, args []string) (*Command, string, error) {
|
||||
// Capture output using a buffer
|
||||
buf := new(bytes.Buffer)
|
||||
c := NewCommand(WithStreams(buf, buf))
|
||||
|
||||
c.SetArgs(args)
|
||||
c.SilenceUsage = true
|
||||
c.SilenceErrors = true
|
||||
c.SetContext(ctx)
|
||||
|
||||
err := c.Execute()
|
||||
return c, buf.String(), err
|
||||
}
|
||||
|
||||
func TestVersion(t *testing.T) {
|
||||
data, err := os.ReadFile("version.txt")
|
||||
if err != nil {
|
||||
@@ -1755,11 +1770,6 @@ func TestMutuallyExclusiveFlags(t *testing.T) {
|
||||
args []string
|
||||
errString string
|
||||
}{
|
||||
{
|
||||
desc: "--prebuilt and --tools-file",
|
||||
args: []string{"--prebuilt", "alloydb", "--tools-file", "my.yaml"},
|
||||
errString: "--prebuilt and --tools-file/--tools-files/--tools-folder flags cannot be used simultaneously",
|
||||
},
|
||||
{
|
||||
desc: "--tools-file and --tools-files",
|
||||
args: []string{"--tools-file", "my.yaml", "--tools-files", "a.yaml,b.yaml"},
|
||||
@@ -1902,3 +1912,228 @@ func TestMergeToolsFiles(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestPrebuiltAndCustomTools(t *testing.T) {
|
||||
t.Setenv("SQLITE_DATABASE", "test.db")
|
||||
// Setup custom tools file
|
||||
customContent := `
|
||||
tools:
|
||||
custom_tool:
|
||||
kind: http
|
||||
source: my-http
|
||||
method: GET
|
||||
path: /
|
||||
description: "A custom tool for testing"
|
||||
sources:
|
||||
my-http:
|
||||
kind: http
|
||||
baseUrl: http://example.com
|
||||
`
|
||||
customFile := filepath.Join(t.TempDir(), "custom.yaml")
|
||||
if err := os.WriteFile(customFile, []byte(customContent), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Tool Conflict File
|
||||
// SQLite prebuilt has a tool named 'list_tables'
|
||||
toolConflictContent := `
|
||||
tools:
|
||||
list_tables:
|
||||
kind: http
|
||||
source: my-http
|
||||
method: GET
|
||||
path: /
|
||||
description: "Conflicting tool"
|
||||
sources:
|
||||
my-http:
|
||||
kind: http
|
||||
baseUrl: http://example.com
|
||||
`
|
||||
toolConflictFile := filepath.Join(t.TempDir(), "tool_conflict.yaml")
|
||||
if err := os.WriteFile(toolConflictFile, []byte(toolConflictContent), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Source Conflict File
|
||||
// SQLite prebuilt has a source named 'sqlite-source'
|
||||
sourceConflictContent := `
|
||||
sources:
|
||||
sqlite-source:
|
||||
kind: http
|
||||
baseUrl: http://example.com
|
||||
tools:
|
||||
dummy_tool:
|
||||
kind: http
|
||||
source: sqlite-source
|
||||
method: GET
|
||||
path: /
|
||||
description: "Dummy"
|
||||
`
|
||||
sourceConflictFile := filepath.Join(t.TempDir(), "source_conflict.yaml")
|
||||
if err := os.WriteFile(sourceConflictFile, []byte(sourceConflictContent), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Toolset Conflict File
|
||||
// SQLite prebuilt has a toolset named 'sqlite_database_tools'
|
||||
toolsetConflictContent := `
|
||||
sources:
|
||||
dummy-src:
|
||||
kind: http
|
||||
baseUrl: http://example.com
|
||||
tools:
|
||||
dummy_tool:
|
||||
kind: http
|
||||
source: dummy-src
|
||||
method: GET
|
||||
path: /
|
||||
description: "Dummy"
|
||||
toolsets:
|
||||
sqlite_database_tools:
|
||||
- dummy_tool
|
||||
`
|
||||
toolsetConflictFile := filepath.Join(t.TempDir(), "toolset_conflict.yaml")
|
||||
if err := os.WriteFile(toolsetConflictFile, []byte(toolsetConflictContent), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
//Legacy Auth File
|
||||
authContent := `
|
||||
authSources:
|
||||
legacy-auth:
|
||||
kind: google
|
||||
clientId: "test-client-id"
|
||||
`
|
||||
authFile := filepath.Join(t.TempDir(), "auth.yaml")
|
||||
if err := os.WriteFile(authFile, []byte(authContent), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
args []string
|
||||
wantErr bool
|
||||
errString string
|
||||
cfgCheck func(server.ServerConfig) error
|
||||
}{
|
||||
{
|
||||
desc: "success mixed",
|
||||
args: []string{"--prebuilt", "sqlite", "--tools-file", customFile},
|
||||
wantErr: false,
|
||||
cfgCheck: func(cfg server.ServerConfig) error {
|
||||
if _, ok := cfg.ToolConfigs["custom_tool"]; !ok {
|
||||
return fmt.Errorf("custom tool not found")
|
||||
}
|
||||
if _, ok := cfg.ToolConfigs["list_tables"]; !ok {
|
||||
return fmt.Errorf("prebuilt tool 'list_tables' not found")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool conflict error",
|
||||
args: []string{"--prebuilt", "sqlite", "--tools-file", toolConflictFile},
|
||||
wantErr: true,
|
||||
errString: "resource conflicts detected",
|
||||
},
|
||||
{
|
||||
desc: "source conflict error",
|
||||
args: []string{"--prebuilt", "sqlite", "--tools-file", sourceConflictFile},
|
||||
wantErr: true,
|
||||
errString: "resource conflicts detected",
|
||||
},
|
||||
{
|
||||
desc: "toolset conflict error",
|
||||
args: []string{"--prebuilt", "sqlite", "--tools-file", toolsetConflictFile},
|
||||
wantErr: true,
|
||||
errString: "resource conflicts detected",
|
||||
},
|
||||
{
|
||||
desc: "legacy auth additive",
|
||||
args: []string{"--prebuilt", "sqlite", "--tools-file", authFile},
|
||||
wantErr: false,
|
||||
cfgCheck: func(cfg server.ServerConfig) error {
|
||||
if _, ok := cfg.AuthServiceConfigs["legacy-auth"]; !ok {
|
||||
return fmt.Errorf("legacy auth source not merged into auth services")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
cmd, output, err := invokeCommandWithContext(ctx, tc.args)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error but got none")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errString) {
|
||||
t.Errorf("expected error message to contain %q, but got %q", tc.errString, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil && err != context.DeadlineExceeded && err != context.Canceled {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(output, "Server ready to serve!") {
|
||||
t.Errorf("server did not start successfully (no ready message found). Output:\n%s", output)
|
||||
}
|
||||
if tc.cfgCheck != nil {
|
||||
if err := tc.cfgCheck(cmd.cfg); err != nil {
|
||||
t.Errorf("config check failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultToolsFileBehavior(t *testing.T) {
|
||||
t.Setenv("SQLITE_DATABASE", "test.db")
|
||||
testCases := []struct {
|
||||
desc string
|
||||
args []string
|
||||
expectRun bool
|
||||
errString string
|
||||
}{
|
||||
{
|
||||
desc: "no flags (defaults to tools.yaml)",
|
||||
args: []string{},
|
||||
expectRun: false,
|
||||
errString: "tools.yaml", // Expect error because tools.yaml doesn't exist in test env
|
||||
},
|
||||
{
|
||||
desc: "prebuilt only (skips tools.yaml)",
|
||||
args: []string{"--prebuilt", "sqlite"},
|
||||
expectRun: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
_, output, err := invokeCommandWithContext(ctx, tc.args)
|
||||
|
||||
if tc.expectRun {
|
||||
if err != nil && err != context.DeadlineExceeded && err != context.Canceled {
|
||||
t.Fatalf("expected server start, got error: %v", err)
|
||||
}
|
||||
// Verify it actually started
|
||||
if !strings.Contains(output, "Server ready to serve!") {
|
||||
t.Errorf("server did not start successfully (no ready message found). Output:\n%s", output)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error reading default file, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errString) {
|
||||
t.Errorf("expected error message to contain %q, but got %q", tc.errString, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.23.0
|
||||
0.24.0
|
||||
|
||||
@@ -234,7 +234,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version = \"0.23.0\" # x-release-please-version\n",
|
||||
"version = \"0.24.0\" # x-release-please-version\n",
|
||||
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||
"\n",
|
||||
"# Make the binary executable\n",
|
||||
|
||||
@@ -103,7 +103,7 @@ To install Toolbox as a binary on Linux (AMD64):
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.23.0
|
||||
export VERSION=0.24.0
|
||||
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
|
||||
chmod +x toolbox
|
||||
```
|
||||
@@ -114,7 +114,7 @@ To install Toolbox as a binary on macOS (Apple Silicon):
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.23.0
|
||||
export VERSION=0.24.0
|
||||
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
|
||||
chmod +x toolbox
|
||||
```
|
||||
@@ -125,7 +125,7 @@ To install Toolbox as a binary on macOS (Intel):
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.23.0
|
||||
export VERSION=0.24.0
|
||||
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
|
||||
chmod +x toolbox
|
||||
```
|
||||
@@ -136,7 +136,7 @@ To install Toolbox as a binary on Windows (Command Prompt):
|
||||
|
||||
```cmd
|
||||
:: see releases page for other versions
|
||||
set VERSION=0.23.0
|
||||
set VERSION=0.24.0
|
||||
curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
|
||||
```
|
||||
|
||||
@@ -146,7 +146,7 @@ To install Toolbox as a binary on Windows (PowerShell):
|
||||
|
||||
```powershell
|
||||
# see releases page for other versions
|
||||
$VERSION = "0.23.0"
|
||||
$VERSION = "0.24.0"
|
||||
curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
|
||||
```
|
||||
|
||||
@@ -158,7 +158,7 @@ You can also install Toolbox as a container:
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.23.0
|
||||
export VERSION=0.24.0
|
||||
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
|
||||
```
|
||||
|
||||
@@ -177,7 +177,7 @@ To install from source, ensure you have the latest version of
|
||||
[Go installed](https://go.dev/doc/install), and then run the following command:
|
||||
|
||||
```sh
|
||||
go install github.com/googleapis/genai-toolbox@v0.23.0
|
||||
go install github.com/googleapis/genai-toolbox@v0.24.0
|
||||
```
|
||||
|
||||
{{% /tab %}}
|
||||
|
||||
@@ -105,7 +105,7 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
239
docs/en/getting-started/prompts_quickstart_antigravity/_index.md
Normal file
239
docs/en/getting-started/prompts_quickstart_antigravity/_index.md
Normal file
@@ -0,0 +1,239 @@
|
||||
---
|
||||
title: "Prompts using Antigravity"
|
||||
type: docs
|
||||
weight: 5
|
||||
description: >
|
||||
How to get started using Toolbox prompts locally with PostgreSQL and [Antigravity](https://antigravity.google/).
|
||||
---
|
||||
|
||||
## Before you begin
|
||||
|
||||
This guide assumes you have already done the following:
|
||||
|
||||
1. Installed [PostgreSQL 16+ and the `psql` client][install-postgres].
|
||||
|
||||
[install-postgres]: https://www.postgresql.org/download/
|
||||
|
||||
## Step 1: Set up your database
|
||||
|
||||
In this section, we will create a database, insert some data that needs to be
|
||||
accessed by our agent, and create a database user for Toolbox to connect with.
|
||||
|
||||
1. Connect to postgres using the `psql` command:
|
||||
|
||||
```bash
|
||||
psql -h 127.0.0.1 -U postgres
|
||||
```
|
||||
|
||||
Here, `postgres` denotes the default postgres superuser.
|
||||
|
||||
{{< notice info >}}
|
||||
|
||||
#### **Having trouble connecting?**
|
||||
|
||||
- **Password Prompt:** If you are prompted for a password for the `postgres`
|
||||
user and do not know it (or a blank password doesn't work), your PostgreSQL
|
||||
installation might require a password or a different authentication method.
|
||||
- **`FATAL: role "postgres" does not exist`:** This error means the default
|
||||
`postgres` superuser role isn't available under that name on your system.
|
||||
- **`Connection refused`:** Ensure your PostgreSQL server is actually running.
|
||||
You can typically check with `sudo systemctl status postgresql` and start it
|
||||
with `sudo systemctl start postgresql` on Linux systems.
|
||||
|
||||
<br/>
|
||||
|
||||
#### **Common Solution**
|
||||
|
||||
For password issues or if the `postgres` role seems inaccessible directly, try
|
||||
switching to the `postgres` operating system user first. This user often has
|
||||
permission to connect without a password for local connections (this is called
|
||||
peer authentication).
|
||||
|
||||
```bash
|
||||
sudo -i -u postgres
|
||||
psql -h 127.0.0.1
|
||||
```
|
||||
|
||||
Once you are in the `psql` shell using this method, you can proceed with the
|
||||
database creation steps below. Afterwards, type `\q` to exit `psql`, and then
|
||||
`exit` to return to your normal user shell.
|
||||
|
||||
If desired, once connected to `psql` as the `postgres` OS user, you can set a
|
||||
password for the `postgres` _database_ user using: `ALTER USER postgres WITH
|
||||
PASSWORD 'your_chosen_password';`. This would allow direct connection with `-U
|
||||
postgres` and a password next time.
|
||||
{{< /notice >}}
|
||||
|
||||
1. Create a new database and a new user:
|
||||
|
||||
{{< notice tip >}}
|
||||
For a real application, it's best to follow the principle of least permission
|
||||
and only grant the privileges your application needs.
|
||||
{{< /notice >}}
|
||||
|
||||
```sql
|
||||
CREATE USER toolbox_user WITH PASSWORD 'my-password';
|
||||
|
||||
CREATE DATABASE toolbox_db;
|
||||
GRANT ALL PRIVILEGES ON DATABASE toolbox_db TO toolbox_user;
|
||||
|
||||
ALTER DATABASE toolbox_db OWNER TO toolbox_user;
|
||||
```
|
||||
|
||||
1. End the database session:
|
||||
|
||||
```bash
|
||||
\q
|
||||
```
|
||||
|
||||
(If you used `sudo -i -u postgres` and then `psql`, remember you might also
|
||||
need to type `exit` after `\q` to leave the `postgres` user's shell
|
||||
session.)
|
||||
|
||||
1. Connect to your database with your new user:
|
||||
|
||||
```bash
|
||||
psql -h 127.0.0.1 -U toolbox_user -d toolbox_db
|
||||
```
|
||||
|
||||
1. Create the required tables using the following commands:
|
||||
|
||||
```sql
|
||||
CREATE TABLE users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(50) NOT NULL,
|
||||
email VARCHAR(100) UNIQUE NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE restaurants (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(100) NOT NULL,
|
||||
location VARCHAR(100)
|
||||
);
|
||||
|
||||
CREATE TABLE reviews (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INT REFERENCES users(id),
|
||||
restaurant_id INT REFERENCES restaurants(id),
|
||||
rating INT CHECK (rating >= 1 AND rating <= 5),
|
||||
review_text TEXT,
|
||||
is_published BOOLEAN DEFAULT false,
|
||||
moderation_status VARCHAR(50) DEFAULT 'pending_manual_review',
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
```
|
||||
|
||||
1. Insert dummy data into the tables.
|
||||
|
||||
```sql
|
||||
INSERT INTO users (id, username, email) VALUES
|
||||
(123, 'jane_d', 'jane.d@example.com'),
|
||||
(124, 'john_s', 'john.s@example.com'),
|
||||
(125, 'sam_b', 'sam.b@example.com');
|
||||
|
||||
INSERT INTO restaurants (id, name, location) VALUES
|
||||
(455, 'Pizza Palace', '123 Main St'),
|
||||
(456, 'The Corner Bistro', '456 Oak Ave'),
|
||||
(457, 'Sushi Spot', '789 Pine Ln');
|
||||
|
||||
INSERT INTO reviews (user_id, restaurant_id, rating, review_text, is_published, moderation_status) VALUES
|
||||
(124, 455, 5, 'Best pizza in town! The crust was perfect.', true, 'approved'),
|
||||
(125, 457, 4, 'Great sushi, very fresh. A bit pricey but worth it.', true, 'approved'),
|
||||
(123, 457, 5, 'Absolutely loved the dragon roll. Will be back!', true, 'approved'),
|
||||
(123, 456, 4, 'The atmosphere was lovely and the food was great. My photo upload might have been weird though.', false, 'pending_manual_review'),
|
||||
(125, 456, 1, 'This review contains inappropriate language.', false, 'rejected');
|
||||
```
|
||||
|
||||
1. End the database session:
|
||||
|
||||
```bash
|
||||
\q
|
||||
```
|
||||
|
||||
## Step 2: Configure Toolbox
|
||||
|
||||
Create a file named `tools.yaml`. This file defines the database connection, the
|
||||
SQL tools available, and the prompts the agents will use.
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-foodiefind-db:
|
||||
kind: postgres
|
||||
host: 127.0.0.1
|
||||
port: 5432
|
||||
database: toolbox_db
|
||||
user: toolbox_user
|
||||
password: my-password
|
||||
tools:
|
||||
find_user_by_email:
|
||||
kind: postgres-sql
|
||||
source: my-foodiefind-db
|
||||
description: Find a user's ID by their email address.
|
||||
parameters:
|
||||
- name: email
|
||||
type: string
|
||||
description: The email address of the user to find.
|
||||
statement: SELECT id FROM users WHERE email = $1;
|
||||
find_restaurant_by_name:
|
||||
kind: postgres-sql
|
||||
source: my-foodiefind-db
|
||||
description: Find a restaurant's ID by its exact name.
|
||||
parameters:
|
||||
- name: name
|
||||
type: string
|
||||
description: The name of the restaurant to find.
|
||||
statement: SELECT id FROM restaurants WHERE name = $1;
|
||||
find_review_by_user_and_restaurant:
|
||||
kind: postgres-sql
|
||||
source: my-foodiefind-db
|
||||
description: Find the full record for a specific review using the user's ID and the restaurant's ID.
|
||||
parameters:
|
||||
- name: user_id
|
||||
type: integer
|
||||
description: The numerical ID of the user.
|
||||
- name: restaurant_id
|
||||
type: integer
|
||||
description: The numerical ID of the restaurant.
|
||||
statement: SELECT * FROM reviews WHERE user_id = $1 AND restaurant_id = $2;
|
||||
prompts:
|
||||
investigate_missing_review:
|
||||
description: "Investigates a user's missing review by finding the user, restaurant, and the review itself, then analyzing its status."
|
||||
arguments:
|
||||
- name: "user_email"
|
||||
description: "The email of the user who wrote the review."
|
||||
- name: "restaurant_name"
|
||||
description: "The name of the restaurant being reviewed."
|
||||
messages:
|
||||
- content: >-
|
||||
**Goal:** Find the review written by the user with email '{{.user_email}}' for the restaurant named '{{.restaurant_name}}' and understand its status.
|
||||
**Workflow:**
|
||||
1. Use the `find_user_by_email` tool with the email '{{.user_email}}' to get the `user_id`.
|
||||
2. Use the `find_restaurant_by_name` tool with the name '{{.restaurant_name}}' to get the `restaurant_id`.
|
||||
3. Use the `find_review_by_user_and_restaurant` tool with the `user_id` and `restaurant_id` you just found.
|
||||
4. Analyze the results from the final tool call. Examine the `is_published` and `moderation_status` fields and explain the review's status to the user in a clear, human-readable sentence.
|
||||
```
|
||||
|
||||
## Step 3: Connect to Antigravity
|
||||
|
||||
Configure the Antigravity to talk to your local Toolbox MCP server.
|
||||
|
||||
1. Click on "MCP Servers" in the Agent Manager.
|
||||

|
||||
|
||||
2. Search for and click on MCP Toolbox for Databases.
|
||||

|
||||
|
||||
3. Click the install button and enter the correct path to the `tools.yaml` file.
|
||||

|
||||
|
||||
4. You should be able to see your tools in the tools tab.
|
||||

|
||||
|
||||
5. Now you can type your query into the agent panel.
|
||||
|
||||
```
|
||||
Investigate the missing review from jane.d@example.com for The corner bistro
|
||||
```
|
||||
|
||||

|
||||
Binary file not shown.
|
After Width: | Height: | Size: 26 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 197 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 145 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 42 KiB |
BIN
docs/en/getting-started/prompts_quickstart_antigravity/tools.png
Normal file
BIN
docs/en/getting-started/prompts_quickstart_antigravity/tools.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 116 KiB |
@@ -13,7 +13,7 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -49,19 +49,19 @@ to expose your developer assistant tools to a Looker instance:
|
||||
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -45,19 +45,19 @@ instance:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -43,19 +43,19 @@ expose your developer assistant tools to a MySQL instance:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -44,19 +44,19 @@ expose your developer assistant tools to a Neo4j instance:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -56,19 +56,19 @@ Omni](https://cloud.google.com/alloydb/omni/current/docs/overview).
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -43,19 +43,19 @@ to expose your developer assistant tools to a SQLite instance:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -16,14 +16,14 @@ description: >
|
||||
| | `--log-level` | Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'. | `info` |
|
||||
| | `--logging-format` | Specify logging format to use. Allowed: 'standard' or 'JSON'. | `standard` |
|
||||
| `-p` | `--port` | Port the server will listen on. | `5000` |
|
||||
| | `--prebuilt` | Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | |
|
||||
| | `--prebuilt` | Use a prebuilt tool configuration by source type. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | |
|
||||
| | `--stdio` | Listens via MCP STDIO instead of acting as a remote HTTP server. | |
|
||||
| | `--telemetry-gcp` | Enable exporting directly to Google Cloud Monitoring. | |
|
||||
| | `--telemetry-otlp` | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318') | |
|
||||
| | `--telemetry-service-name` | Sets the value of the service.name resource attribute for telemetry data. | `toolbox` |
|
||||
| | `--tools-file` | File path specifying the tool configuration. Cannot be used with --prebuilt, --tools-files, or --tools-folder. | |
|
||||
| | `--tools-files` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --prebuilt, --tools-file, or --tools-folder. | |
|
||||
| | `--tools-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --prebuilt, --tools-file, or --tools-files. | |
|
||||
| | `--tools-file` | File path specifying the tool configuration. Cannot be used with --tools-files or --tools-folder. | |
|
||||
| | `--tools-files` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file or --tools-folder. | |
|
||||
| | `--tools-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file or --tools-files. | |
|
||||
| | `--ui` | Launches the Toolbox UI web server. | |
|
||||
| | `--allowed-origins` | Specifies a list of origins permitted to access this server. | `*` |
|
||||
| `-v` | `--version` | version for toolbox | |
|
||||
@@ -46,6 +46,9 @@ description: >
|
||||
```bash
|
||||
# Basic server with custom port configuration
|
||||
./toolbox --tools-file "tools.yaml" --port 8080
|
||||
|
||||
# Server with prebuilt + custom tools configurations
|
||||
./toolbox --tools-file tools.yaml --prebuilt alloydb-postgres
|
||||
```
|
||||
|
||||
### Tool Configuration Sources
|
||||
@@ -72,8 +75,8 @@ The CLI supports multiple mutually exclusive ways to specify tool configurations
|
||||
|
||||
{{< notice tip >}}
|
||||
The CLI enforces mutual exclusivity between configuration source flags,
|
||||
preventing simultaneous use of `--prebuilt` with file-based options, and
|
||||
ensuring only one of `--tools-file`, `--tools-files`, or `--tools-folder` is
|
||||
preventing simultaneous use of the file-based options ensuring only one of
|
||||
`--tools-file`, `--tools-files`, or `--tools-folder` is
|
||||
used at a time.
|
||||
{{< /notice >}}
|
||||
|
||||
|
||||
@@ -13,6 +13,12 @@ allowing developers to interact with and take action on databases.
|
||||
See guides, [Connect from your IDE](../how-to/connect-ide/_index.md), for
|
||||
details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP.
|
||||
|
||||
{{< notice tip >}}
|
||||
You can now use `--prebuilt` along `--tools-file`, `--tools-files`, or
|
||||
`--tools-folder` to combine prebuilt configs with custom tools.
|
||||
See [Usage Examples](../reference/cli.md#examples).
|
||||
{{< /notice >}}
|
||||
|
||||
## AlloyDB Postgres
|
||||
|
||||
* `--prebuilt` value: `alloydb-postgres`
|
||||
|
||||
@@ -31,6 +31,9 @@ to a database by following these instructions][csql-mysql-quickstart].
|
||||
- [`mysql-list-active-queries`](../tools/mysql/mysql-list-active-queries.md)
|
||||
List active queries in Cloud SQL for MySQL.
|
||||
|
||||
- [`mysql-get-query-plan`](../tools/mysql/mysql-get-query-plan.md)
|
||||
Provide information about how MySQL executes a SQL statement (EXPLAIN).
|
||||
|
||||
- [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md)
|
||||
List tables in a Cloud SQL for MySQL database.
|
||||
|
||||
|
||||
@@ -25,6 +25,9 @@ reliability, performance, and ease of use.
|
||||
- [`mysql-list-active-queries`](../tools/mysql/mysql-list-active-queries.md)
|
||||
List active queries in MySQL.
|
||||
|
||||
- [`mysql-get-query-plan`](../tools/mysql/mysql-get-query-plan.md)
|
||||
Provide information about how MySQL executes a SQL statement (EXPLAIN).
|
||||
|
||||
- [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md)
|
||||
List tables in a MySQL database.
|
||||
|
||||
|
||||
39
docs/en/resources/tools/mysql/mysql-get-query-plan.md
Normal file
39
docs/en/resources/tools/mysql/mysql-get-query-plan.md
Normal file
@@ -0,0 +1,39 @@
|
||||
---
|
||||
title: "mysql-get-query-plan"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "mysql-get-query-plan" tool gets the execution plan for a SQL statement against a MySQL
|
||||
database.
|
||||
aliases:
|
||||
- /resources/tools/mysql-get-query-plan
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `mysql-get-query-plan` tool gets the execution plan for a SQL statement against a MySQL
|
||||
database. It's compatible with any of the following sources:
|
||||
|
||||
- [cloud-sql-mysql](../../sources/cloud-sql-mysql.md)
|
||||
- [mysql](../../sources/mysql.md)
|
||||
|
||||
`mysql-get-query-plan` takes one input parameter `sql_statement` and gets the execution plan for the SQL
|
||||
statement against the `source`.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
get_query_plan_tool:
|
||||
kind: mysql-get-query-plan
|
||||
source: my-mysql-instance
|
||||
description: Use this tool to get the execution plan for a sql statement.
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "mysql-get-query-plan". |
|
||||
| source | string | true | Name of the source the SQL should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
@@ -771,7 +771,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version = \"0.23.0\" # x-release-please-version\n",
|
||||
"version = \"0.24.0\" # x-release-please-version\n",
|
||||
"! curl -L -o /content/toolbox https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||
"\n",
|
||||
"# Make the binary executable\n",
|
||||
|
||||
@@ -123,7 +123,7 @@ In this section, we will download and install the Toolbox binary.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
export VERSION="0.23.0"
|
||||
export VERSION="0.24.0"
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -220,7 +220,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version = \"0.23.0\" # x-release-please-version\n",
|
||||
"version = \"0.24.0\" # x-release-please-version\n",
|
||||
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||
"\n",
|
||||
"# Make the binary executable\n",
|
||||
|
||||
@@ -179,7 +179,7 @@ to use BigQuery, and then run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ In this section, we will download Toolbox and run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "mcp-toolbox-for-databases",
|
||||
"version": "0.23.0",
|
||||
"version": "0.24.0",
|
||||
"description": "MCP Toolbox for Databases is an open-source MCP server for more than 30 different datasources.",
|
||||
"contextFileName": "MCP-TOOLBOX-EXTENSION.md"
|
||||
}
|
||||
6
go.mod
6
go.mod
@@ -22,7 +22,7 @@ require (
|
||||
github.com/cenkalti/backoff/v5 v5.0.3
|
||||
github.com/couchbase/gocb/v2 v2.11.1
|
||||
github.com/couchbase/tools-common/http v1.0.9
|
||||
github.com/elastic/elastic-transport-go/v8 v8.7.0
|
||||
github.com/elastic/elastic-transport-go/v8 v8.8.0
|
||||
github.com/elastic/go-elasticsearch/v9 v9.2.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/go-chi/chi/v5 v5.2.3
|
||||
@@ -33,7 +33,7 @@ require (
|
||||
github.com/go-playground/validator/v10 v10.28.0
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/goccy/go-yaml v1.18.0
|
||||
github.com/godror/godror v0.49.4
|
||||
github.com/godror/godror v0.49.6
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
@@ -42,7 +42,7 @@ require (
|
||||
github.com/microsoft/go-mssqldb v1.9.3
|
||||
github.com/nakagami/firebirdsql v0.9.15
|
||||
github.com/neo4j/neo4j-go-driver/v5 v5.28.4
|
||||
github.com/redis/go-redis/v9 v9.16.0
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/sijms/go-ora/v2 v2.9.0
|
||||
github.com/spf13/cobra v1.10.1
|
||||
github.com/thlib/go-timezone-local v0.0.7
|
||||
|
||||
12
go.sum
12
go.sum
@@ -822,8 +822,8 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3
|
||||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/elastic/elastic-transport-go/v8 v8.7.0 h1:OgTneVuXP2uip4BA658Xi6Hfw+PeIOod2rY3GVMGoVE=
|
||||
github.com/elastic/elastic-transport-go/v8 v8.7.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk=
|
||||
github.com/elastic/elastic-transport-go/v8 v8.8.0 h1:7k1Ua+qluFr6p1jfJjGDl97ssJS/P7cHNInzfxgBQAo=
|
||||
github.com/elastic/elastic-transport-go/v8 v8.8.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk=
|
||||
github.com/elastic/go-elasticsearch/v9 v9.2.0 h1:COeL/g20+ixnUbffe4Wfbu88emrHjAq/LhVfmrjqRQs=
|
||||
github.com/elastic/go-elasticsearch/v9 v9.2.0/go.mod h1:2PB5YQPpY5tWbF65MRqzEXA31PZOdXCkloQSOZtU14I=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
@@ -915,8 +915,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/godror/godror v0.49.4 h1:8kKWKoR17nPX7u10hr4GwD4u10hzTZED9ihdkuzRrKI=
|
||||
github.com/godror/godror v0.49.4/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8=
|
||||
github.com/godror/godror v0.49.6 h1:ts4ZGw8uLJ42e1D7aXmVuSrld0/lzUzmIUjuUuQOgGM=
|
||||
github.com/godror/godror v0.49.6/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8=
|
||||
github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw=
|
||||
github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
@@ -1222,8 +1222,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w=
|
||||
github.com/redis/go-redis/v9 v9.16.0 h1:OotgqgLSRCmzfqChbQyG1PHC3tLNR89DG4jdOERSEP4=
|
||||
github.com/redis/go-redis/v9 v9.16.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
|
||||
@@ -32,16 +32,9 @@ tools:
|
||||
source: cloud-sql-mysql-source
|
||||
description: Lists top N (default 10) ongoing queries from processlist and innodb_trx, ordered by execution time in descending order. Returns detailed information of those queries in json format, including process id, query, transaction duration, transaction wait duration, process time, transaction state, process state, username with host, transaction rows locked, transaction rows modified, and db schema.
|
||||
get_query_plan:
|
||||
kind: mysql-sql
|
||||
kind: mysql-get-query-plan
|
||||
source: cloud-sql-mysql-source
|
||||
description: "Provide information about how MySQL executes a SQL statement. Common use cases include: 1) analyze query plan to improve its performance, and 2) determine effectiveness of existing indexes and evalueate new ones."
|
||||
statement: |
|
||||
EXPLAIN FORMAT=JSON {{.sql_statement}};
|
||||
templateParameters:
|
||||
- name: sql_statement
|
||||
type: string
|
||||
description: "the SQL statement to explain"
|
||||
required: true
|
||||
list_tables:
|
||||
kind: mysql-list-tables
|
||||
source: cloud-sql-mysql-source
|
||||
|
||||
@@ -36,16 +36,9 @@ tools:
|
||||
source: mysql-source
|
||||
description: Lists top N (default 10) ongoing queries from processlist and innodb_trx, ordered by execution time in descending order. Returns detailed information of those queries in json format, including process id, query, transaction duration, transaction wait duration, process time, transaction state, process state, username with host, transaction rows locked, transaction rows modified, and db schema.
|
||||
get_query_plan:
|
||||
kind: mysql-sql
|
||||
kind: mysql-get-query-plan
|
||||
source: mysql-source
|
||||
description: "Provide information about how MySQL executes a SQL statement. Common use cases include: 1) analyze query plan to improve its performance, and 2) determine effectiveness of existing indexes and evalueate new ones."
|
||||
statement: |
|
||||
EXPLAIN FORMAT=JSON {{.sql_statement}};
|
||||
templateParameters:
|
||||
- name: sql_statement
|
||||
type: string
|
||||
description: "the SQL statement to explain"
|
||||
required: true
|
||||
list_tables:
|
||||
kind: mysql-list-tables
|
||||
source: mysql-source
|
||||
|
||||
@@ -172,7 +172,14 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization(s.ResourceMgr) {
|
||||
clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
s.logger.DebugContext(ctx, errMsg.Error())
|
||||
_ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound))
|
||||
return
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
err = fmt.Errorf("tool requires client authorization but access token is missing from the request header")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
@@ -255,7 +262,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||
if tool.RequiresClientAuthorization(s.ResourceMgr) {
|
||||
if clientAuth {
|
||||
// Propagate the original 401/403 error.
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||
|
||||
@@ -77,9 +77,9 @@ func (t MockTool) Authorized(verifiedAuthServices []string) bool {
|
||||
return !t.unauthorized
|
||||
}
|
||||
|
||||
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) bool {
|
||||
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
|
||||
// defaulted to false
|
||||
return t.requiresClientAuthrorization
|
||||
return t.requiresClientAuthrorization, nil
|
||||
}
|
||||
|
||||
func (t MockTool) McpManifest() tools.McpManifest {
|
||||
@@ -119,8 +119,8 @@ func (t MockTool) McpManifest() tools.McpManifest {
|
||||
return mcpManifest
|
||||
}
|
||||
|
||||
func (t MockTool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
// MockPrompt is used to mock prompts in tests
|
||||
|
||||
@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
|
||||
// Get access token
|
||||
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
|
||||
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
|
||||
// Get access token
|
||||
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
|
||||
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
@@ -101,10 +101,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
|
||||
// Get access token
|
||||
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
|
||||
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
@@ -176,7 +186,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if tool.RequiresClientAuthorization(resourceMgr) {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
@@ -30,26 +30,6 @@ import (
|
||||
|
||||
const SourceKind string = "alloydb-admin"
|
||||
|
||||
type userAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
|
||||
}
|
||||
return rt.next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
@@ -87,10 +67,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = &http.Client{
|
||||
Transport: &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
@@ -99,10 +76,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
@@ -136,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetDefaultProject() string {
|
||||
return s.DefaultProject
|
||||
}
|
||||
|
||||
func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbrestapi.Service, error) {
|
||||
if s.UseClientOAuth {
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
|
||||
@@ -29,26 +29,6 @@ import (
|
||||
const SourceKind string = "cloud-gemini-data-analytics"
|
||||
const Endpoint string = "https://geminidataanalytics.googleapis.com"
|
||||
|
||||
type userAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
|
||||
}
|
||||
return rt.next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
@@ -87,10 +67,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = &http.Client{
|
||||
Transport: &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
@@ -100,10 +77,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
@@ -133,6 +107,14 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetProjectID() string {
|
||||
return s.ProjectID
|
||||
}
|
||||
|
||||
func (s *Source) GetBaseURL() string {
|
||||
return s.BaseURL
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
@@ -140,10 +122,7 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
|
||||
}
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: s.userAgent,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
|
||||
return baseClient, nil
|
||||
}
|
||||
return s.Client, nil
|
||||
|
||||
@@ -29,26 +29,6 @@ import (
|
||||
|
||||
const SourceKind string = "cloud-monitoring"
|
||||
|
||||
type userAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
|
||||
}
|
||||
return rt.next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
@@ -86,10 +66,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = &http.Client{
|
||||
Transport: &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
@@ -98,18 +75,15 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Config: r,
|
||||
BaseURL: "https://monitoring.googleapis.com",
|
||||
Client: client,
|
||||
UserAgent: ua,
|
||||
baseURL: "https://monitoring.googleapis.com",
|
||||
client: client,
|
||||
userAgent: ua,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
@@ -118,9 +92,9 @@ var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Config
|
||||
BaseURL string `yaml:"baseUrl"`
|
||||
Client *http.Client
|
||||
UserAgent string
|
||||
baseURL string
|
||||
client *http.Client
|
||||
userAgent string
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
@@ -131,6 +105,18 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) BaseURL() string {
|
||||
return s.baseURL
|
||||
}
|
||||
|
||||
func (s *Source) Client() *http.Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
func (s *Source) UserAgent() string {
|
||||
return s.userAgent
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
@@ -139,7 +125,7 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
|
||||
}
|
||||
return s.Client, nil
|
||||
return s.client, nil
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
|
||||
@@ -30,26 +30,6 @@ import (
|
||||
|
||||
const SourceKind string = "cloud-sql-admin"
|
||||
|
||||
type userAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
|
||||
}
|
||||
return rt.next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
@@ -88,10 +68,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = &http.Client{
|
||||
Transport: &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
@@ -100,10 +77,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
@@ -136,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetDefaultProject() string {
|
||||
return s.DefaultProject
|
||||
}
|
||||
|
||||
func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) {
|
||||
if s.UseClientOAuth {
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
|
||||
@@ -107,7 +107,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
|
||||
s := &Source{
|
||||
Config: r,
|
||||
Client: &client,
|
||||
client: &client,
|
||||
}
|
||||
return s, nil
|
||||
|
||||
@@ -117,7 +117,7 @@ var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Config
|
||||
Client *http.Client
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
@@ -127,3 +127,19 @@ func (s *Source) SourceKind() string {
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) HttpDefaultHeaders() map[string]string {
|
||||
return s.DefaultHeaders
|
||||
}
|
||||
|
||||
func (s *Source) HttpBaseURL() string {
|
||||
return s.BaseURL
|
||||
}
|
||||
|
||||
func (s *Source) HttpQueryParams() map[string]string {
|
||||
return s.QueryParams
|
||||
}
|
||||
|
||||
func (s *Source) Client() *http.Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
@@ -160,10 +160,6 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetApiSettings() *rtl.ApiSettings {
|
||||
return s.ApiSettings
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return strings.ToLower(s.UseClientOAuth) != "false"
|
||||
}
|
||||
@@ -188,6 +184,30 @@ func (s *Source) GoogleCloudTokenSourceWithScope(ctx context.Context, scope stri
|
||||
return google.DefaultTokenSource(ctx, scope)
|
||||
}
|
||||
|
||||
func (s *Source) LookerClient() *v4.LookerSDK {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func (s *Source) LookerApiSettings() *rtl.ApiSettings {
|
||||
return s.ApiSettings
|
||||
}
|
||||
|
||||
func (s *Source) LookerShowHiddenFields() bool {
|
||||
return s.ShowHiddenFields
|
||||
}
|
||||
|
||||
func (s *Source) LookerShowHiddenModels() bool {
|
||||
return s.ShowHiddenModels
|
||||
}
|
||||
|
||||
func (s *Source) LookerShowHiddenExplores() bool {
|
||||
return s.ShowHiddenExplores
|
||||
}
|
||||
|
||||
func (s *Source) LookerSessionLength() int64 {
|
||||
return s.SessionLength
|
||||
}
|
||||
|
||||
func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) {
|
||||
cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...)
|
||||
if err != nil {
|
||||
|
||||
@@ -96,6 +96,14 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetProject() string {
|
||||
return s.Project
|
||||
}
|
||||
|
||||
func (s *Source) GetLocation() string {
|
||||
return s.Location
|
||||
}
|
||||
|
||||
func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the create-cluster tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -97,7 +102,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -107,7 +111,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-cluster tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
manifest tools.Manifest
|
||||
@@ -120,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
@@ -151,7 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -198,10 +206,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the create-instance tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -108,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-instance tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
manifest tools.Manifest
|
||||
@@ -121,6 +124,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
@@ -147,7 +155,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -208,10 +216,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the create-user tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -108,9 +112,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the create-user tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -121,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
@@ -147,7 +154,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -208,10 +215,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-get-cluster"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the get-cluster tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -104,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the get-cluster tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters
|
||||
|
||||
manifest tools.Manifest
|
||||
@@ -117,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -132,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -167,10 +176,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-get-instance"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the get-instance tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the get-instance tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters
|
||||
|
||||
AllParams parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'instance' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-get-user"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the get-user tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the get-user tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters
|
||||
|
||||
AllParams parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-list-clusters"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the list-clusters tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -93,7 +99,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -103,9 +108,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the list-clusters tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -116,6 +119,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -127,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -162,10 +170,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-list-instances"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the list-instances tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the list-instances tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-list-users"
|
||||
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Configuration for the list-users tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the list-users tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -25,9 +25,9 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-wait-for-operation"
|
||||
@@ -89,6 +89,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
UseClientAuthorization() bool
|
||||
GetService(context.Context, string) (*alloydb.Service, error)
|
||||
}
|
||||
|
||||
// Config defines the configuration for the wait-for-operation tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -119,12 +125,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.DefaultProject
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
@@ -180,7 +186,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -194,19 +199,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Tool represents the wait-for-operation tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
|
||||
Source *alloydbadmin.Source
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
|
||||
// Polling configuration
|
||||
Delay time.Duration
|
||||
MaxDelay time.Duration
|
||||
Multiplier float64
|
||||
MaxRetries int
|
||||
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -215,6 +217,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
@@ -230,7 +237,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("missing 'operation' parameter")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -363,10 +370,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
@@ -47,11 +46,6 @@ type compatibleSource interface {
|
||||
PostgresPool() *pgxpool.Pool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &alloydbpg.Source{}
|
||||
|
||||
var compatibleSources = [...]string{alloydbpg.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
numParams := len(cfg.NLConfigParameters)
|
||||
quotedNameParts := make([]string, 0, numParams)
|
||||
placeholderParts := make([]string, 0, numParams)
|
||||
@@ -126,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Config: cfg,
|
||||
Parameters: cfg.NLConfigParameters,
|
||||
Statement: stmt,
|
||||
Pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -139,9 +120,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Pool *pgxpool.Pool
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -152,6 +131,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pool := source.PostgresPool()
|
||||
|
||||
sliceParams := params.AsSlice()
|
||||
allParamValues := make([]any, len(sliceParams)+1)
|
||||
allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question
|
||||
@@ -160,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
allParamValues[i+2] = fmt.Sprintf("%s", param)
|
||||
}
|
||||
|
||||
results, err := t.Pool.Query(ctx, t.Statement, allParamValues...)
|
||||
results, err := pool.Query(ctx, t.Statement, allParamValues...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues)
|
||||
}
|
||||
@@ -203,10 +188,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -57,11 +57,6 @@ type compatibleSource interface {
|
||||
BigQuerySession() bigqueryds.BigQuerySessionProvider
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
@@ -136,17 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -156,17 +144,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -175,23 +155,27 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke runs the contribution analysis.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
inputData, ok := paramsMap["input_data"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
var err error
|
||||
bqClient := source.BigQueryClient()
|
||||
restService := source.BigQueryRestService()
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
||||
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -229,9 +213,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
var inputDataSource string
|
||||
trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
|
||||
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
{Key: "session_id", Value: session.ID},
|
||||
}
|
||||
}
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
@@ -252,7 +236,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
queryStats := dryRunJob.Statistics.Query
|
||||
if queryStats != nil {
|
||||
for _, tableRef := range queryStats.ReferencedTables {
|
||||
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
|
||||
}
|
||||
}
|
||||
@@ -262,18 +246,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
inputDataSource = fmt.Sprintf("(%s)", inputData)
|
||||
} else {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
parts := strings.Split(inputData, ".")
|
||||
var projectID, datasetID string
|
||||
switch len(parts) {
|
||||
case 3: // project.dataset.table
|
||||
projectID, datasetID = parts[0], parts[1]
|
||||
case 2: // dataset.table
|
||||
projectID, datasetID = t.Client.Project(), parts[0]
|
||||
projectID, datasetID = source.BigQueryClient().Project(), parts[0]
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData)
|
||||
}
|
||||
if !t.IsDatasetAllowed(projectID, datasetID) {
|
||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData)
|
||||
}
|
||||
}
|
||||
@@ -292,7 +276,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
// Get session from provider if in protected mode.
|
||||
// Otherwise, a new session will be created by the first query.
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -385,10 +369,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -105,11 +104,6 @@ type CAPayload struct {
|
||||
ClientIdEnum string `json:"clientIdEnum"`
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -135,7 +129,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
@@ -153,31 +147,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
params := parameters.Parameters{userQueryParameter, tableRefsParameter}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// Get cloud-platform token source for Gemini Data Analytics API during initialization
|
||||
var bigQueryTokenSourceWithScope oauth2.TokenSource
|
||||
if !s.UseClientAuthorization() {
|
||||
ctx := context.Background()
|
||||
ts, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
|
||||
}
|
||||
bigQueryTokenSourceWithScope = ts
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Project: s.BigQueryProject(),
|
||||
Location: s.BigQueryLocation(),
|
||||
Parameters: params,
|
||||
Client: s.BigQueryClient(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
TokenSource: bigQueryTokenSourceWithScope,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
MaxQueryResultRows: s.GetMaxQueryResultRows(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -187,18 +162,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project string
|
||||
Location string
|
||||
Client *bigqueryapi.Client
|
||||
TokenSource oauth2.TokenSource
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
MaxQueryResultRows int
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -206,11 +172,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokenStr string
|
||||
var err error
|
||||
|
||||
// Get credentials for the API call
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
// Use client-side access token
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
|
||||
@@ -220,11 +190,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Get cloud-platform token source for Gemini Data Analytics API during initialization
|
||||
tokenSource, err := source.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
|
||||
}
|
||||
|
||||
// Use cloud-platform token source for Gemini Data Analytics API
|
||||
if t.TokenSource == nil {
|
||||
if tokenSource == nil {
|
||||
return nil, fmt.Errorf("cloud-platform token source is missing")
|
||||
}
|
||||
token, err := t.TokenSource.Token()
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err)
|
||||
}
|
||||
@@ -245,17 +221,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
for _, tableRef := range tableRefs {
|
||||
if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
|
||||
if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Construct URL, headers, and payload
|
||||
projectID := t.Project
|
||||
location := t.Location
|
||||
projectID := source.BigQueryProject()
|
||||
location := source.BigQueryLocation()
|
||||
if location == "" {
|
||||
location = "us"
|
||||
}
|
||||
@@ -279,7 +255,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
// Call the streaming API
|
||||
response, err := getStream(caURL, payload, headers, t.MaxQueryResultRows)
|
||||
response, err := getStream(caURL, payload, headers, source.GetMaxQueryResultRows())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
|
||||
}
|
||||
@@ -303,8 +279,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
// StreamMessage represents a single message object from the streaming API response.
|
||||
@@ -580,6 +560,6 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s
|
||||
return append(messages, newMessage)
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -60,11 +60,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -90,7 +85,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
var sqlDescriptionBuilder strings.Builder
|
||||
@@ -136,18 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
WriteMode: s.BigQueryWriteMode(),
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -157,18 +144,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
WriteMode string
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -176,6 +154,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
@@ -186,17 +169,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
bqClient := source.BigQueryClient()
|
||||
restService := source.BigQueryRestService()
|
||||
|
||||
var err error
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
||||
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -204,8 +186,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
var session *bigqueryds.Session
|
||||
if t.WriteMode == bigqueryds.WriteModeProtected {
|
||||
session, err = t.SessionProvider(ctx)
|
||||
if source.BigQueryWriteMode() == bigqueryds.WriteModeProtected {
|
||||
session, err = source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
|
||||
}
|
||||
@@ -221,7 +203,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
|
||||
switch t.WriteMode {
|
||||
switch source.BigQueryWriteMode() {
|
||||
case bigqueryds.WriteModeBlocked:
|
||||
if statementType != "SELECT" {
|
||||
return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed")
|
||||
@@ -235,7 +217,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
switch statementType {
|
||||
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
|
||||
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
|
||||
@@ -270,7 +252,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
} else if statementType != "SELECT" {
|
||||
// If dry run yields no tables, fall back to the parser for non-SELECT statements
|
||||
// to catch unsafe operations like EXECUTE IMMEDIATE.
|
||||
parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project())
|
||||
parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project())
|
||||
if parseErr != nil {
|
||||
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
|
||||
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
|
||||
@@ -282,7 +264,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
parts := strings.Split(tableID, ".")
|
||||
if len(parts) == 3 {
|
||||
projectID, datasetID := parts[0], parts[1]
|
||||
if !t.IsDatasetAllowed(projectID, datasetID) {
|
||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
|
||||
}
|
||||
}
|
||||
@@ -374,10 +356,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -57,11 +57,6 @@ type compatibleSource interface {
|
||||
BigQuerySession() bigqueryds.BigQuerySessionProvider
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
@@ -116,17 +111,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
AllowedDatasets: allowedDatasets,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -136,17 +124,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -154,6 +134,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
historyData, ok := paramsMap["history_data"].(string)
|
||||
if !ok {
|
||||
@@ -188,17 +173,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
var err error
|
||||
bqClient := source.BigQueryClient()
|
||||
restService := source.BigQueryRestService()
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -207,9 +191,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
var historyDataSource string
|
||||
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
|
||||
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -218,7 +202,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
{Key: "session_id", Value: session.ID},
|
||||
}
|
||||
}
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps)
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
@@ -230,7 +214,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
queryStats := dryRunJob.Statistics.Query
|
||||
if queryStats != nil {
|
||||
for _, tableRef := range queryStats.ReferencedTables {
|
||||
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
|
||||
}
|
||||
}
|
||||
@@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
historyDataSource = fmt.Sprintf("(%s)", historyData)
|
||||
} else {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
parts := strings.Split(historyData, ".")
|
||||
var projectID, datasetID string
|
||||
|
||||
@@ -249,13 +233,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
projectID = parts[0]
|
||||
datasetID = parts[1]
|
||||
case 2: // dataset.table
|
||||
projectID = t.Client.Project()
|
||||
projectID = source.BigQueryClient().Project()
|
||||
datasetID = parts[0]
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectID, datasetID) {
|
||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
|
||||
}
|
||||
}
|
||||
@@ -279,7 +263,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// JobStatistics.QueryStatistics.StatementType
|
||||
query := bqClient.Query(sql)
|
||||
query.Location = bqClient.Location
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -349,10 +333,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -54,11 +54,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -84,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
defaultProjectID := s.BigQueryProject()
|
||||
@@ -104,14 +99,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -121,15 +112,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -137,6 +122,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
@@ -148,22 +138,21 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
var err error
|
||||
bqClient := source.BigQueryClient()
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
}
|
||||
|
||||
@@ -193,10 +182,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -55,11 +55,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
defaultProjectID := s.BigQueryProject()
|
||||
@@ -108,14 +103,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -125,15 +116,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -141,6 +126,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
@@ -157,20 +147,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
bqClient := source.BigQueryClient()
|
||||
|
||||
var err error
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -203,10 +192,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -52,11 +52,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -82,7 +77,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
var projectParameter parameters.Parameter
|
||||
@@ -103,14 +98,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
AllowedDatasets: allowedDatasets,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -120,15 +111,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
AllowedDatasets []string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -136,8 +121,13 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
return t.AllowedDatasets, nil
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
return source.BigQueryAllowedDatasets(), nil
|
||||
}
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
@@ -145,14 +135,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
bqClient := source.BigQueryClient()
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -197,10 +187,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -55,11 +55,6 @@ type compatibleSource interface {
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
defaultProjectID := s.BigQueryProject()
|
||||
@@ -107,14 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -124,15 +115,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -140,6 +125,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
@@ -151,18 +141,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
bqClient := source.BigQueryClient()
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -208,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -51,11 +51,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -72,20 +67,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// Initialize the search configuration with the provided sources
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Get the Dataplex client using the method from the source
|
||||
makeCatalogClient := s.MakeDataplexCatalogClient()
|
||||
|
||||
prompt := parameters.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.")
|
||||
datasetIds := parameters.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", parameters.NewStringParameter("datasetId", "The IDs of the bigquery dataset."))
|
||||
projectIds := parameters.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", parameters.NewStringParameter("projectId", "The IDs of the bigquery project."))
|
||||
@@ -100,11 +81,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, params, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
MakeCatalogClient: makeCatalogClient,
|
||||
ProjectID: s.BigQueryProject(),
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: params.Manifest(),
|
||||
@@ -117,12 +95,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters
|
||||
UseClientOAuth bool
|
||||
MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error)
|
||||
ProjectID string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -133,8 +108,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func constructSearchQueryHelper(predicate string, operator string, items []string) string {
|
||||
@@ -207,6 +186,11 @@ func ExtractType(resourceString string) string {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
pageSize := int32(paramsMap["pageSize"].(int))
|
||||
prompt, _ := paramsMap["prompt"].(string)
|
||||
@@ -228,14 +212,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)),
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", source.BigQueryProject()),
|
||||
PageSize: pageSize,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
catalogClient, dataplexClientCreator, _ := t.MakeCatalogClient()
|
||||
catalogClient, dataplexClientCreator, _ := source.MakeDataplexCatalogClient()()
|
||||
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
@@ -248,7 +232,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
it := catalogClient.SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.BigQueryProject())
|
||||
}
|
||||
|
||||
var results []Response
|
||||
@@ -288,6 +272,6 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -57,11 +57,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -81,18 +76,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -102,15 +85,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -120,15 +98,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -136,6 +108,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
|
||||
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
|
||||
|
||||
@@ -212,16 +189,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
lowLevelParams = append(lowLevelParams, lowLevelParam)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
bqClient := source.BigQueryClient()
|
||||
restService := source.BigQueryRestService()
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
||||
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -232,8 +209,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
query.Location = bqClient.Location
|
||||
|
||||
connProps := []*bigqueryapi.ConnectionProperty{}
|
||||
if t.SessionProvider != nil {
|
||||
session, err := t.SessionProvider(ctx)
|
||||
if source.BigQuerySession() != nil {
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -311,10 +288,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"cloud.google.com/go/bigtable"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigtabledb "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -46,11 +45,6 @@ type compatibleSource interface {
|
||||
BigtableClient() *bigtable.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigtabledb.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigtabledb.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -93,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Client: s.BigtableClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -105,9 +86,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *bigtable.Client
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -156,6 +135,11 @@ func getMapParamsType(tparams parameters.Parameters, params parameters.ParamValu
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -172,7 +156,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("fail to get map params: %w", err)
|
||||
}
|
||||
|
||||
ps, err := t.Client.PrepareStatement(
|
||||
ps, err := source.BigtableClient().PrepareStatement(
|
||||
ctx,
|
||||
newStatement,
|
||||
mapParamsType,
|
||||
@@ -224,10 +208,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
gocql "github.com/apache/cassandra-gocql-driver/v2"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -46,10 +45,6 @@ type compatibleSource interface {
|
||||
CassandraSession() *gocql.Session
|
||||
}
|
||||
|
||||
var _ compatibleSource = &cassandra.Source{}
|
||||
|
||||
var compatibleSources = [...]string{cassandra.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -61,20 +56,15 @@ type Config struct {
|
||||
TemplateParameters parameters.Parameters `yaml:"templateParameters"`
|
||||
}
|
||||
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
// ToolConfigKind implements tools.ToolConfig.
|
||||
func (c Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
// Initialize implements tools.ToolConfig.
|
||||
func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[c.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", c.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, err := parameters.ProcessParameters(c.TemplateParameters, c.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -85,25 +75,17 @@ func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
t := Tool{
|
||||
Config: c,
|
||||
AllParams: allParameters,
|
||||
Session: s.CassandraSession(),
|
||||
manifest: tools.Manifest{Description: c.Description, Parameters: paramManifest, AuthRequired: c.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// ToolConfigKind implements tools.ToolConfig.
|
||||
func (c Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
var _ tools.ToolConfig = Config{}
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Session *gocql.Session
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -113,8 +95,8 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// RequiresClientAuthorization implements tools.Tool.
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Authorized implements tools.Tool.
|
||||
@@ -124,6 +106,11 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
|
||||
// Invoke implements tools.Tool.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -135,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
iter := t.Session.Query(newStatement, sliceParams...).IterContext(ctx)
|
||||
iter := source.CassandraSession().Query(newStatement, sliceParams...).IterContext(ctx)
|
||||
|
||||
// Create a slice to store the out
|
||||
var out []map[string]interface{}
|
||||
@@ -170,8 +157,6 @@ func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any)
|
||||
return parameters.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -25,12 +25,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const executeSQLKind string = "clickhouse-execute-sql"
|
||||
|
||||
func init() {
|
||||
@@ -47,6 +41,10 @@ func newExecuteSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -62,16 +60,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", executeSQLKind, compatibleSources)
|
||||
}
|
||||
|
||||
sqlParameter := parameters.NewStringParameter("sql", "The SQL statement to execute.")
|
||||
params := parameters.Parameters{sqlParameter}
|
||||
|
||||
@@ -80,7 +68,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -91,9 +78,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Pool *sql.DB
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -103,13 +88,18 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"])
|
||||
}
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, sql)
|
||||
results, err := source.ClickHousePool().QueryContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -183,10 +173,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -25,12 +25,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const listDatabasesKind string = "clickhouse-list-databases"
|
||||
|
||||
func init() {
|
||||
@@ -47,6 +41,10 @@ func newListDatabasesConfig(ctx context.Context, name string, decoder *yaml.Deco
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -63,23 +61,12 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listDatabasesKind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, _ := parameters.ProcessParameters(nil, cfg.Parameters)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -90,9 +77,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *sql.DB
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -102,10 +87,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Query to list all databases
|
||||
query := "SHOW DATABASES"
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, query)
|
||||
results, err := source.ClickHousePool().QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -146,10 +136,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -32,21 +31,6 @@ func TestListDatabasesConfigToolConfigKind(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDatabasesConfigInitializeMissingSource(t *testing.T) {
|
||||
cfg := Config{
|
||||
Name: "test-list-databases",
|
||||
Kind: listDatabasesKind,
|
||||
Source: "missing-source",
|
||||
Description: "Test list databases tool",
|
||||
}
|
||||
|
||||
srcs := map[string]sources.Source{}
|
||||
_, err := cfg.Initialize(srcs)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlClickHouseListDatabases(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
|
||||
@@ -25,12 +25,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const listTablesKind string = "clickhouse-list-tables"
|
||||
const databaseKey string = "database"
|
||||
|
||||
@@ -48,6 +42,10 @@ func newListTablesConfig(ctx context.Context, name string, decoder *yaml.Decoder
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -64,16 +62,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listTablesKind, compatibleSources)
|
||||
}
|
||||
|
||||
databaseParameter := parameters.NewStringParameter(databaseKey, "The database to list tables from.")
|
||||
params := parameters.Parameters{databaseParameter}
|
||||
|
||||
@@ -83,7 +71,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -94,9 +81,7 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *sql.DB
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -106,6 +91,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
database, ok := mapParams[databaseKey].(string)
|
||||
if !ok {
|
||||
@@ -115,7 +105,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// Query to list all tables in the specified database
|
||||
query := fmt.Sprintf("SHOW TABLES FROM %s", database)
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, query)
|
||||
results, err := source.ClickHousePool().QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -157,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -32,21 +31,6 @@ func TestListTablesConfigToolConfigKind(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListTablesConfigInitializeMissingSource(t *testing.T) {
|
||||
cfg := Config{
|
||||
Name: "test-list-tables",
|
||||
Kind: listTablesKind,
|
||||
Source: "missing-source",
|
||||
Description: "Test list tables tool",
|
||||
}
|
||||
|
||||
srcs := map[string]sources.Source{}
|
||||
_, err := cfg.Initialize(srcs)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlClickHouseListTables(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
|
||||
@@ -25,21 +25,15 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const sqlKind string = "clickhouse-sql"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(sqlKind, newSQLConfig) {
|
||||
if !tools.Register(sqlKind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", sqlKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
@@ -47,6 +41,10 @@ func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tool
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -65,23 +63,12 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", sqlKind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, _ := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -93,7 +80,6 @@ var _ tools.Tool = Tool{}
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
Pool *sql.DB
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -103,6 +89,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -115,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
sliceParams := newParams.AsSlice()
|
||||
results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...)
|
||||
results, err := source.ClickHousePool().QueryContext(ctx, newStatement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -191,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return false
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -142,66 +142,6 @@ func TestSQLConfigInitializeValidSource(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLConfigInitializeMissingSource(t *testing.T) {
|
||||
config := Config{
|
||||
Name: "test-tool",
|
||||
Kind: sqlKind,
|
||||
Source: "missing-source",
|
||||
Description: "Test tool",
|
||||
Statement: "SELECT 1",
|
||||
Parameters: parameters.Parameters{},
|
||||
}
|
||||
|
||||
sources := map[string]sources.Source{}
|
||||
|
||||
_, err := config.Initialize(sources)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing source, got nil")
|
||||
}
|
||||
|
||||
expectedErr := `no source named "missing-source" configured`
|
||||
if err.Error() != expectedErr {
|
||||
t.Errorf("Expected error %q, got %q", expectedErr, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// mockIncompatibleSource is a mock source that doesn't implement the compatibleSource interface
|
||||
type mockIncompatibleSource struct{}
|
||||
|
||||
func (m *mockIncompatibleSource) SourceKind() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSQLConfigInitializeIncompatibleSource(t *testing.T) {
|
||||
config := Config{
|
||||
Name: "test-tool",
|
||||
Kind: sqlKind,
|
||||
Source: "incompatible-source",
|
||||
Description: "Test tool",
|
||||
Statement: "SELECT 1",
|
||||
Parameters: parameters.Parameters{},
|
||||
}
|
||||
|
||||
mockSource := &mockIncompatibleSource{}
|
||||
|
||||
sources := map[string]sources.Source{
|
||||
"incompatible-source": mockSource,
|
||||
}
|
||||
|
||||
_, err := config.Initialize(sources)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for incompatible source, got nil")
|
||||
}
|
||||
|
||||
if err.Error() == "" {
|
||||
t.Error("Expected non-empty error message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolManifest(t *testing.T) {
|
||||
tool := Tool{
|
||||
manifest: tools.Manifest{
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetProjectID() string
|
||||
GetBaseURL() string
|
||||
UseClientAuthorization() bool
|
||||
GetClient(context.Context, string) (*http.Client, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -64,18 +70,6 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*cloudgdasrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-gemini-data-analytics`", kind)
|
||||
}
|
||||
|
||||
// Define the parameters for the Gemini Data Analytics Query API
|
||||
// The prompt is the only input parameter.
|
||||
allParameters := parameters.Parameters{
|
||||
@@ -87,7 +81,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Source: s,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
@@ -99,7 +92,6 @@ var _ tools.Tool = Tool{}
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters
|
||||
Source *cloudgdasrc.Source
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -110,6 +102,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
|
||||
// Invoke executes the tool logic
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
prompt, ok := paramsMap["prompt"].(string)
|
||||
if !ok {
|
||||
@@ -118,11 +115,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
// The API endpoint itself always uses the "global" location.
|
||||
apiLocation := "global"
|
||||
apiParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, apiLocation)
|
||||
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", t.Source.BaseURL, apiParent)
|
||||
apiParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), apiLocation)
|
||||
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", source.GetBaseURL(), apiParent)
|
||||
|
||||
// The parent in the request payload uses the tool's configured location.
|
||||
payloadParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, t.Location)
|
||||
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
|
||||
|
||||
payload := &QueryDataRequest{
|
||||
Parent: payloadParent,
|
||||
@@ -138,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
// Parse the access token if provided
|
||||
var tokenStr string
|
||||
if t.RequiresClientAuthorization(resourceMgr) {
|
||||
if source.UseClientAuthorization() {
|
||||
var err error
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
@@ -146,7 +143,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
client, err := t.Source.GetClient(ctx, tokenStr)
|
||||
client, err := source.GetClient(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
|
||||
}
|
||||
@@ -196,10 +193,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
@@ -172,9 +173,8 @@ func TestInitialize(t *testing.T) {
|
||||
}
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
cfg cloudgdatool.Config
|
||||
expectErr bool
|
||||
desc string
|
||||
cfg cloudgdatool.Config
|
||||
}{
|
||||
{
|
||||
desc: "successful initialization",
|
||||
@@ -185,29 +185,6 @@ func TestInitialize(t *testing.T) {
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
desc: "missing source",
|
||||
cfg: cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "non-existent-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
desc: "incompatible source kind",
|
||||
cfg: cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "incompatible-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -219,16 +196,11 @@ func TestInitialize(t *testing.T) {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tool, err := tc.cfg.Initialize(srcs)
|
||||
if tc.expectErr && err == nil {
|
||||
t.Fatalf("expected an error but got none")
|
||||
}
|
||||
if !tc.expectErr && err != nil {
|
||||
if err != nil {
|
||||
t.Fatalf("did not expect an error but got: %v", err)
|
||||
}
|
||||
if !tc.expectErr {
|
||||
// Basic sanity check on the returned tool
|
||||
_ = tool // Avoid unused variable error
|
||||
}
|
||||
// Basic sanity check on the returned tool
|
||||
_ = tool // Avoid unused variable error
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -361,8 +333,10 @@ func TestInvoke(t *testing.T) {
|
||||
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
|
||||
}
|
||||
|
||||
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil)
|
||||
|
||||
// Invoke the tool
|
||||
result, err := tool.Invoke(ctx, nil, params, "") // No accessToken needed for ADC client
|
||||
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
|
||||
if err != nil {
|
||||
t.Fatalf("tool invocation failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -62,11 +62,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,35 +78,16 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
urlParameter := parameters.NewStringParameter(pageURLKey, "The full URL of the FHIR page to fetch. This would be the value of `Bundle.entry.link.url` field within the response returned from FHIR search or FHIR patient everything operations.")
|
||||
params := parameters.Parameters{urlParameter}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -121,14 +97,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -136,13 +107,18 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url, ok := params.AsMap()[pageURLKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
|
||||
}
|
||||
|
||||
var httpClient *http.Client
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
@@ -150,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
|
||||
httpClient = oauth2.NewClient(ctx, ts)
|
||||
} else {
|
||||
// The t.Service object holds a client with the default credentials.
|
||||
// The source.Service() object holds a client with the default credentials.
|
||||
// However, the client is not exported, so we have to create a new one.
|
||||
var err error
|
||||
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
|
||||
@@ -201,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -62,11 +62,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -92,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
idParameter := parameters.NewStringParameter(patientIDKey, "The ID of the patient FHIR resource for which the information is required")
|
||||
@@ -106,17 +101,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -126,15 +114,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -142,7 +124,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -151,20 +138,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", t.Project, t.Region, t.Dataset, storeID, patientID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID)
|
||||
var opts []googleapi.CallOption
|
||||
if val, ok := params.AsMap()[typeFilterKey]; ok {
|
||||
types, ok := val.([]any)
|
||||
@@ -225,10 +212,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -78,11 +78,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -108,7 +103,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -140,17 +135,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -160,15 +148,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -176,19 +158,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -261,7 +248,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
opts = append(opts, googleapi.QueryParameter("_summary", "text"))
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search patient resources: %w", err)
|
||||
@@ -298,10 +285,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -51,11 +51,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -72,33 +67,15 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -108,13 +85,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
Project, Region, Dataset string
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -122,22 +95,26 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
svc := t.Service
|
||||
var err error
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
@@ -161,10 +138,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
|
||||
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
|
||||
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -59,11 +59,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -89,7 +84,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
typeParameter := parameters.NewStringParameter(typeKey, "The FHIR resource type to retrieve (e.g., Patient, Observation).")
|
||||
@@ -102,17 +97,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -122,15 +110,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -138,7 +120,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -152,20 +139,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", t.Project, t.Region, t.Dataset, storeID, resType, resID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", source.Project(), source.Region(), source.DatasetID(), storeID, resType, resID)
|
||||
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
|
||||
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
|
||||
resp, err := call.Do()
|
||||
@@ -204,10 +191,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
|
||||
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
|
||||
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -111,15 +87,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -127,29 +97,33 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
svc := t.Service
|
||||
var err error
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.DicomStore
|
||||
for _, store := range stores.DicomStores {
|
||||
if len(t.AllowedStores) == 0 {
|
||||
if len(source.AllowedDICOMStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
@@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok {
|
||||
if _, ok := source.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
@@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string {
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedFHIRStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -111,15 +87,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -127,29 +97,33 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
svc := t.Service
|
||||
var err error
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.FhirStore
|
||||
for _, store := range stores.FhirStores {
|
||||
if len(t.AllowedStores) == 0 {
|
||||
if len(source.AllowedFHIRStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
@@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok {
|
||||
if _, ok := source.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
@@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -61,11 +61,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -91,7 +86,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -107,17 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -127,15 +115,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -143,19 +125,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -177,7 +164,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey)
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
|
||||
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
|
||||
call.Header().Set("Accept", "image/jpeg")
|
||||
@@ -214,10 +201,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -68,11 +68,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -98,7 +93,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -121,17 +116,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -141,15 +129,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -157,19 +139,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -204,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom instances: %w", err)
|
||||
@@ -244,10 +231,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
@@ -65,11 +65,6 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &healthcareds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{healthcareds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
@@ -95,7 +90,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
params := parameters.Parameters{
|
||||
@@ -117,17 +112,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
Project: s.Project(),
|
||||
Region: s.Region(),
|
||||
Dataset: s.DatasetID(),
|
||||
AllowedStores: s.AllowedDICOMStores(),
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ServiceCreator: s.ServiceCreator(),
|
||||
Service: s.Service(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Config: cfg,
|
||||
Parameters: params,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -137,15 +125,9 @@ var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
|
||||
Project, Region, Dataset string
|
||||
AllowedStores map[string]struct{}
|
||||
Service *healthcare.Service
|
||||
ServiceCreator healthcareds.HealthcareServiceCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
@@ -153,19 +135,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := t.Service
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = t.ServiceCreator(tokenStr)
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
@@ -187,7 +174,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom series: %w", err)
|
||||
@@ -227,10 +214,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.UseClientOAuth
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user