mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-12 17:09:48 -05:00
Compare commits
14 Commits
source-imp
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c3720e31d | ||
|
|
3cd3c39d66 | ||
|
|
0691a6f715 | ||
|
|
467b96a23b | ||
|
|
4abf0c39e7 | ||
|
|
dd7b9de623 | ||
|
|
41b518b955 | ||
|
|
6e8a9eb8ec | ||
|
|
ef8f3b02f2 | ||
|
|
351b007fe3 | ||
|
|
9d1feca108 | ||
|
|
17b41f6453 | ||
|
|
a7799757c9 | ||
|
|
d961e373e1 |
@@ -51,6 +51,10 @@ ignoreFiles = ["quickstart/shared", "quickstart/python", "quickstart/js", "quick
|
||||
# Add a new version block here before every release
|
||||
# The order of versions in this file is mirrored into the dropdown
|
||||
|
||||
[[params.versions]]
|
||||
version = "v0.25.0"
|
||||
url = "https://googleapis.github.io/genai-toolbox/v0.25.0/"
|
||||
|
||||
[[params.versions]]
|
||||
version = "v0.24.0"
|
||||
url = "https://googleapis.github.io/genai-toolbox/v0.24.0/"
|
||||
|
||||
25
CHANGELOG.md
25
CHANGELOG.md
@@ -1,5 +1,30 @@
|
||||
# Changelog
|
||||
|
||||
## [0.25.0](https://github.com/googleapis/genai-toolbox/compare/v0.24.0...v0.25.0) (2026-01-08)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* Add `embeddingModel` support ([#2121](https://github.com/googleapis/genai-toolbox/issues/2121)) ([9c62f31](https://github.com/googleapis/genai-toolbox/commit/9c62f313ff5edf0a3b5b8a3e996eba078fba4095))
|
||||
* Add `allowed-hosts` flag ([#2254](https://github.com/googleapis/genai-toolbox/issues/2254)) ([17b41f6](https://github.com/googleapis/genai-toolbox/commit/17b41f64531b8fe417c28ada45d1992ba430dc1b))
|
||||
* Add parameter default value to manifest ([#2264](https://github.com/googleapis/genai-toolbox/issues/2264)) ([9d1feca](https://github.com/googleapis/genai-toolbox/commit/9d1feca10810fa42cb4c94a409252f1bd373ee36))
|
||||
* **snowflake:** Add Snowflake Source and Tools ([#858](https://github.com/googleapis/genai-toolbox/issues/858)) ([b706b5b](https://github.com/googleapis/genai-toolbox/commit/b706b5bc685aeda277f277868bae77d38d5fd7b6))
|
||||
* **prebuilt/cloud-sql-mysql:** Update CSQL MySQL prebuilt tools to use IAM ([#2202](https://github.com/googleapis/genai-toolbox/issues/2202)) ([731a32e](https://github.com/googleapis/genai-toolbox/commit/731a32e5360b4d6862d81fcb27d7127c655679a8))
|
||||
* **sources/bigquery:** Make credentials scope configurable ([#2210](https://github.com/googleapis/genai-toolbox/issues/2210)) ([a450600](https://github.com/googleapis/genai-toolbox/commit/a4506009b93771b77fb05ae97044f914967e67ed))
|
||||
* **sources/trino:** Add ssl verification options and fix docs example ([#2155](https://github.com/googleapis/genai-toolbox/issues/2155)) ([4a4cf1e](https://github.com/googleapis/genai-toolbox/commit/4a4cf1e712b671853678dba99c4dc49dd4fc16a2))
|
||||
* **tools/looker:** Add ability to set destination folder with `make_look` and `make_dashboard`. ([#2245](https://github.com/googleapis/genai-toolbox/issues/2245)) ([eb79339](https://github.com/googleapis/genai-toolbox/commit/eb793398cd1cc4006d9808ccda5dc7aea5e92bd5))
|
||||
* **tools/postgressql:** Add tool to list store procedure ([#2156](https://github.com/googleapis/genai-toolbox/issues/2156)) ([cf0fc51](https://github.com/googleapis/genai-toolbox/commit/cf0fc515b57d9b84770076f3c0c5597c4597ef62))
|
||||
* **tools/postgressql:** Add Parameter `embeddedBy` config support ([#2151](https://github.com/googleapis/genai-toolbox/issues/2151)) ([17b70cc](https://github.com/googleapis/genai-toolbox/commit/17b70ccaa754d15bcc33a1a3ecb7e652520fa600))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* **server:** Add `embeddingModel` config initialization ([#2281](https://github.com/googleapis/genai-toolbox/issues/2281)) ([a779975](https://github.com/googleapis/genai-toolbox/commit/a7799757c9345f99b6d2717841fbf792d364e1a2))
|
||||
* **sources/cloudgda:** Add import for cloudgda source ([#2217](https://github.com/googleapis/genai-toolbox/issues/2217)) ([7daa411](https://github.com/googleapis/genai-toolbox/commit/7daa4111f4ebfb0a35319fd67a8f7b9f0f99efcf))
|
||||
* **tools/alloydb-wait-for-operation:** Fix connection message generation ([#2228](https://github.com/googleapis/genai-toolbox/issues/2228)) ([7053fbb](https://github.com/googleapis/genai-toolbox/commit/7053fbb1953653143d39a8510916ea97a91022a6))
|
||||
* **tools/alloydbainl:** Only add psv when NL Config Param is defined ([#2265](https://github.com/googleapis/genai-toolbox/issues/2265)) ([ef8f3b0](https://github.com/googleapis/genai-toolbox/commit/ef8f3b02f2f38ce94a6ba9acf35d08b9469bef4e))
|
||||
* **tools/looker:** Looker client OAuth nil pointer error ([#2231](https://github.com/googleapis/genai-toolbox/issues/2231)) ([268700b](https://github.com/googleapis/genai-toolbox/commit/268700bdbf8281de0318d60ca613ed3672990b20))
|
||||
|
||||
## [0.24.0](https://github.com/googleapis/genai-toolbox/compare/v0.23.0...v0.24.0) (2025-12-19)
|
||||
|
||||
|
||||
|
||||
17
DEVELOPER.md
17
DEVELOPER.md
@@ -379,6 +379,23 @@ to approve PRs for main. TeamSync is used to create this team from the MDB
|
||||
Group `toolbox-contributors`. Googlers who are developing for MCP-Toolbox
|
||||
but aren't part of the core team should join this group.
|
||||
|
||||
### Issue/PR Triage and SLO
|
||||
After an issue is created, maintainers will assign the following labels:
|
||||
* `Priority` (defaulted to P0)
|
||||
* `Type` (if applicable)
|
||||
* `Product` (if applicable)
|
||||
|
||||
All incoming issues and PRs will follow the following SLO:
|
||||
| Type | Priority | Objective |
|
||||
|-----------------|----------|------------------------------------------------------------------------|
|
||||
| Feature Request | P0 | Must respond within **5 days** |
|
||||
| Process | P0 | Must respond within **5 days** |
|
||||
| Bugs | P0 | Must respond within **5 days**, and resolve/closure within **14 days** |
|
||||
| Bugs | P1 | Must respond within **7 days**, and resolve/closure within **90 days** |
|
||||
| Bugs            | P2       | Must respond within **30 days**                                        |
|
||||
|
||||
_Types that are not listed in the table do not adhere to any SLO._
|
||||
|
||||
### Releasing
|
||||
|
||||
Toolbox has two types of releases: versioned and continuous. It uses Google
|
||||
|
||||
16
README.md
16
README.md
@@ -140,7 +140,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```sh
|
||||
> # see releases page for other versions
|
||||
> export VERSION=0.24.0
|
||||
> export VERSION=0.25.0
|
||||
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
|
||||
> chmod +x toolbox
|
||||
> ```
|
||||
@@ -153,7 +153,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```sh
|
||||
> # see releases page for other versions
|
||||
> export VERSION=0.24.0
|
||||
> export VERSION=0.25.0
|
||||
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
|
||||
> chmod +x toolbox
|
||||
> ```
|
||||
@@ -166,7 +166,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```sh
|
||||
> # see releases page for other versions
|
||||
> export VERSION=0.24.0
|
||||
> export VERSION=0.25.0
|
||||
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
|
||||
> chmod +x toolbox
|
||||
> ```
|
||||
@@ -179,7 +179,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```cmd
|
||||
> :: see releases page for other versions
|
||||
> set VERSION=0.24.0
|
||||
> set VERSION=0.25.0
|
||||
> curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
|
||||
> ```
|
||||
>
|
||||
@@ -191,7 +191,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```powershell
|
||||
> # see releases page for other versions
|
||||
> $VERSION = "0.24.0"
|
||||
> $VERSION = "0.25.0"
|
||||
> curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
|
||||
> ```
|
||||
>
|
||||
@@ -204,7 +204,7 @@ You can also install Toolbox as a container:
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.24.0
|
||||
export VERSION=0.25.0
|
||||
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
|
||||
```
|
||||
|
||||
@@ -228,7 +228,7 @@ To install from source, ensure you have the latest version of
|
||||
[Go installed](https://go.dev/doc/install), and then run the following command:
|
||||
|
||||
```sh
|
||||
go install github.com/googleapis/genai-toolbox@v0.24.0
|
||||
go install github.com/googleapis/genai-toolbox@v0.25.0
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -272,7 +272,7 @@ To run Toolbox from binary:
|
||||
To run the server after pulling the [container image](#installing-the-server):
|
||||
|
||||
```sh
|
||||
export VERSION=0.11.0 # Use the version you pulled
|
||||
export VERSION=0.24.0 # Use the version you pulled
|
||||
docker run -p 5000:5000 \
|
||||
-v $(pwd)/tools.yaml:/app/tools.yaml \
|
||||
us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION \
|
||||
|
||||
@@ -381,7 +381,9 @@ func NewCommand(opts ...Option) *Command {
|
||||
flags.BoolVar(&cmd.cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.")
|
||||
flags.BoolVar(&cmd.cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.")
|
||||
flags.BoolVar(&cmd.cfg.UI, "ui", false, "Launches the Toolbox UI web server.")
|
||||
// TODO: Insecure by default. Might consider updating this for v1.0.0
|
||||
flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.")
|
||||
flags.StringSliceVar(&cmd.cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.")
|
||||
|
||||
// wrap RunE command so that we have access to original Command object
|
||||
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }
|
||||
@@ -947,6 +949,7 @@ func run(cmd *Command) error {
|
||||
|
||||
cmd.cfg.SourceConfigs = finalToolsFile.Sources
|
||||
cmd.cfg.AuthServiceConfigs = finalToolsFile.AuthServices
|
||||
cmd.cfg.EmbeddingModelConfigs = finalToolsFile.EmbeddingModels
|
||||
cmd.cfg.ToolConfigs = finalToolsFile.Tools
|
||||
cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets
|
||||
cmd.cfg.PromptConfigs = finalToolsFile.Prompts
|
||||
|
||||
@@ -67,6 +67,9 @@ func withDefaults(c server.ServerConfig) server.ServerConfig {
|
||||
if c.AllowedOrigins == nil {
|
||||
c.AllowedOrigins = []string{"*"}
|
||||
}
|
||||
if c.AllowedHosts == nil {
|
||||
c.AllowedHosts = []string{"*"}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -220,6 +223,13 @@ func TestServerConfigFlags(t *testing.T) {
|
||||
AllowedOrigins: []string{"http://foo.com", "http://bar.com"},
|
||||
}),
|
||||
},
|
||||
{
|
||||
desc: "allowed hosts",
|
||||
args: []string{"--allowed-hosts", "http://foo.com,http://bar.com"},
|
||||
want: withDefaults(server.ServerConfig{
|
||||
AllowedHosts: []string{"http://foo.com", "http://bar.com"},
|
||||
}),
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.24.0
|
||||
0.25.0
|
||||
|
||||
18
docs/en/blogs/_index.md
Normal file
18
docs/en/blogs/_index.md
Normal file
@@ -0,0 +1,18 @@
|
||||
---
|
||||
title: "Featured Articles"
|
||||
weight: 3
|
||||
description: Toolbox Medium Blogs
|
||||
manualLink: "https://medium.com/@mcp_toolbox"
|
||||
manualLinkTarget: _blank
|
||||
---
|
||||
|
||||
<html>
|
||||
<head>
|
||||
<title>Redirecting to Featured Articles</title>
|
||||
<link rel="canonical" href="https://medium.com/@mcp_toolbox"/>
|
||||
<meta http-equiv="refresh" content="0;url=https://medium.com/@mcp_toolbox"/>
|
||||
</head>
|
||||
<body>
|
||||
<p>If you are not automatically redirected, please <a href="https://medium.com/@mcp_toolbox">follow this link to our articles</a>.</p>
|
||||
</body>
|
||||
</html>
|
||||
@@ -234,7 +234,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version = \"0.24.0\" # x-release-please-version\n",
|
||||
"version = \"0.25.0\" # x-release-please-version\n",
|
||||
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||
"\n",
|
||||
"# Make the binary executable\n",
|
||||
|
||||
@@ -103,7 +103,7 @@ To install Toolbox as a binary on Linux (AMD64):
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.24.0
|
||||
export VERSION=0.25.0
|
||||
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
|
||||
chmod +x toolbox
|
||||
```
|
||||
@@ -114,7 +114,7 @@ To install Toolbox as a binary on macOS (Apple Silicon):
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.24.0
|
||||
export VERSION=0.25.0
|
||||
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
|
||||
chmod +x toolbox
|
||||
```
|
||||
@@ -125,7 +125,7 @@ To install Toolbox as a binary on macOS (Intel):
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.24.0
|
||||
export VERSION=0.25.0
|
||||
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
|
||||
chmod +x toolbox
|
||||
```
|
||||
@@ -136,7 +136,7 @@ To install Toolbox as a binary on Windows (Command Prompt):
|
||||
|
||||
```cmd
|
||||
:: see releases page for other versions
|
||||
set VERSION=0.24.0
|
||||
set VERSION=0.25.0
|
||||
curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
|
||||
```
|
||||
|
||||
@@ -146,7 +146,7 @@ To install Toolbox as a binary on Windows (PowerShell):
|
||||
|
||||
```powershell
|
||||
# see releases page for other versions
|
||||
$VERSION = "0.24.0"
|
||||
$VERSION = "0.25.0"
|
||||
curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
|
||||
```
|
||||
|
||||
@@ -158,7 +158,7 @@ You can also install Toolbox as a container:
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.24.0
|
||||
export VERSION=0.25.0
|
||||
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
|
||||
```
|
||||
|
||||
@@ -177,7 +177,7 @@ To install from source, ensure you have the latest version of
|
||||
[Go installed](https://go.dev/doc/install), and then run the following command:
|
||||
|
||||
```sh
|
||||
go install github.com/googleapis/genai-toolbox@v0.24.0
|
||||
go install github.com/googleapis/genai-toolbox@v0.25.0
|
||||
```
|
||||
|
||||
{{% /tab %}}
|
||||
|
||||
@@ -105,7 +105,7 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -100,19 +100,19 @@ After you install Looker in the MCP Store, resources and tools from the server a
|
||||
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -45,19 +45,19 @@ instance:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -43,19 +43,19 @@ expose your developer assistant tools to a MySQL instance:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -44,19 +44,19 @@ expose your developer assistant tools to a Neo4j instance:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -56,19 +56,19 @@ Omni](https://cloud.google.com/alloydb/omni/current/docs/overview).
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -43,19 +43,19 @@ to expose your developer assistant tools to a SQLite instance:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -68,7 +68,12 @@ networks:
|
||||
```
|
||||
|
||||
{{< notice tip >}}
|
||||
To prevent DNS rebinding attack, use the `--allowed-origins` flag to specify a
|
||||
To prevent DNS rebinding attacks, use the `--allowed-hosts` flag to specify a
|
||||
list of hosts for validation. E.g. `command: [ "toolbox",
|
||||
"--tools-file", "/config/tools.yaml", "--address", "0.0.0.0",
|
||||
"--allowed-hosts", "localhost:5000"]`
|
||||
|
||||
To implement CORS, use the `--allowed-origins` flag to specify a
|
||||
list of origins permitted to access the server. E.g. `command: [ "toolbox",
|
||||
"--tools-file", "/config/tools.yaml", "--address", "0.0.0.0",
|
||||
"--allowed-origins", "https://foo.bar"]`
|
||||
|
||||
@@ -188,9 +188,13 @@ description: >
|
||||
path: tools.yaml
|
||||
```
|
||||
|
||||
{{< notice tip >}}
|
||||
{{< notice tip >}}
|
||||
To prevent DNS rebinding attacks, use the `--allowed-hosts` flag to specify a
|

||||
list of hosts permitted to access the server. E.g. `args: ["--address",
|

||||
"0.0.0.0", "--allowed-hosts", "foo.bar:5000"]`
|
||||
|
||||
To implement CORS, use the `--allowed-origins` flag to specify a
|
||||
list of origins permitted to access the server. E.g. `args: ["--address",
|
||||
"0.0.0.0", "--allowed-origins", "https://foo.bar"]`
|
||||
{{< /notice >}}
|
||||
|
||||
|
||||
@@ -142,14 +142,18 @@ deployment will time out.
|
||||
|
||||
### Update deployed server to be secure
|
||||
|
||||
To prevent DNS rebinding attack, use the `--allowed-origins` flag to specify a
|
||||
list of origins permitted to access the server. In order to do that, you will
|
||||
To prevent DNS rebinding attacks, use the `--allowed-hosts` flag to specify a
|
||||
list of hosts. In order to do that, you will
|
||||
have to re-deploy the cloud run service with the new flag.
|
||||
|
||||
To implement CORS checks, use the `--allowed-origins` flag to specify a list of
|
||||
origins permitted to access the server.
|
||||
|
||||
1. Set an environment variable to the cloud run url:
|
||||
|
||||
```bash
|
||||
export URL=<cloud run url>
|
||||
export HOST=<cloud run host>
|
||||
```
|
||||
|
||||
2. Redeploy Toolbox:
|
||||
@@ -160,7 +164,7 @@ have to re-deploy the cloud run service with the new flag.
|
||||
--service-account toolbox-identity \
|
||||
--region us-central1 \
|
||||
--set-secrets "/app/tools.yaml=tools:latest" \
|
||||
--args="--tools-file=/app/tools.yaml","--address=0.0.0.0","--port=8080","--allowed-origins=$URL"
|
||||
--args="--tools-file=/app/tools.yaml","--address=0.0.0.0","--port=8080","--allowed-origins=$URL","--allowed-hosts=$HOST"
|
||||
# --allow-unauthenticated # https://cloud.google.com/run/docs/authenticating/public#gcloud
|
||||
```
|
||||
|
||||
@@ -172,7 +176,7 @@ have to re-deploy the cloud run service with the new flag.
|
||||
--service-account toolbox-identity \
|
||||
--region us-central1 \
|
||||
--set-secrets "/app/tools.yaml=tools:latest" \
|
||||
--args="--tools-file=/app/tools.yaml","--address=0.0.0.0","--port=8080","--allowed-origins=$URL" \
|
||||
--args="--tools-file=/app/tools.yaml","--address=0.0.0.0","--port=8080","--allowed-origins=$URL","--allowed-hosts=$HOST" \
|
||||
# TODO(dev): update the following to match your VPC if necessary
|
||||
--network default \
|
||||
--subnet default
|
||||
|
||||
@@ -8,25 +8,26 @@ description: >
|
||||
|
||||
## Reference
|
||||
|
||||
| Flag (Short) | Flag (Long) | Description | Default |
|
||||
|--------------|----------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------|
|
||||
| `-a` | `--address` | Address of the interface the server will listen on. | `127.0.0.1` |
|
||||
| | `--disable-reload` | Disables dynamic reloading of tools file. | |
|
||||
| `-h` | `--help` | help for toolbox | |
|
||||
| | `--log-level` | Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'. | `info` |
|
||||
| | `--logging-format` | Specify logging format to use. Allowed: 'standard' or 'JSON'. | `standard` |
|
||||
| `-p` | `--port` | Port the server will listen on. | `5000` |
|
||||
| | `--prebuilt` | Use a prebuilt tool configuration by source type. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | |
|
||||
| | `--stdio` | Listens via MCP STDIO instead of acting as a remote HTTP server. | |
|
||||
| | `--telemetry-gcp` | Enable exporting directly to Google Cloud Monitoring. | |
|
||||
| | `--telemetry-otlp` | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318') | |
|
||||
| | `--telemetry-service-name` | Sets the value of the service.name resource attribute for telemetry data. | `toolbox` |
|
||||
| Flag (Short) | Flag (Long) | Description | Default |
|
||||
|--------------|----------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------|
|
||||
| `-a` | `--address` | Address of the interface the server will listen on. | `127.0.0.1` |
|
||||
| | `--disable-reload` | Disables dynamic reloading of tools file. | |
|
||||
| `-h` | `--help` | help for toolbox | |
|
||||
| | `--log-level` | Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'. | `info` |
|
||||
| | `--logging-format` | Specify logging format to use. Allowed: 'standard' or 'JSON'. | `standard` |
|
||||
| `-p` | `--port` | Port the server will listen on. | `5000` |
|
||||
| | `--prebuilt` | Use a prebuilt tool configuration by source type. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | |
|
||||
| | `--stdio` | Listens via MCP STDIO instead of acting as a remote HTTP server. | |
|
||||
| | `--telemetry-gcp` | Enable exporting directly to Google Cloud Monitoring. | |
|
||||
| | `--telemetry-otlp` | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318') | |
|
||||
| | `--telemetry-service-name` | Sets the value of the service.name resource attribute for telemetry data. | `toolbox` |
|
||||
| | `--tools-file` | File path specifying the tool configuration. Cannot be used with --tools-files or --tools-folder. | |
|
||||
| | `--tools-files` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file or --tools-folder. | |
|
||||
| | `--tools-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file or --tools-files. | |
|
||||
| | `--ui` | Launches the Toolbox UI web server. | |
|
||||
| | `--allowed-origins` | Specifies a list of origins permitted to access this server. | `*` |
|
||||
| `-v` | `--version` | version for toolbox | |
|
||||
| | `--ui` | Launches the Toolbox UI web server. | |
|
||||
| | `--allowed-origins` | Specifies a list of origins permitted to access this server for CORS access. | `*` |
|
||||
| | `--allowed-hosts` | Specifies a list of hosts permitted to access this server to prevent DNS rebinding attacks. | `*` |
|
||||
| `-v` | `--version` | version for toolbox | |
|
||||
|
||||
## Examples
|
||||
|
||||
|
||||
@@ -134,6 +134,7 @@ sources:
|
||||
# scopes: # Optional: List of OAuth scopes to request.
|
||||
# - "https://www.googleapis.com/auth/bigquery"
|
||||
# - "https://www.googleapis.com/auth/drive.readonly"
|
||||
# maxQueryResultRows: 50 # Optional: Limits the number of rows returned by queries. Defaults to 50.
|
||||
```
|
||||
|
||||
Initialize a BigQuery source that uses the client's access token:
|
||||
@@ -153,6 +154,7 @@ sources:
|
||||
# scopes: # Optional: List of OAuth scopes to request.
|
||||
# - "https://www.googleapis.com/auth/bigquery"
|
||||
# - "https://www.googleapis.com/auth/drive.readonly"
|
||||
# maxQueryResultRows: 50 # Optional: Limits the number of rows returned by queries. Defaults to 50.
|
||||
```
|
||||
|
||||
## Reference
|
||||
@@ -167,3 +169,4 @@ sources:
|
||||
| useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. **Note:** This cannot be used with `writeMode: protected`. |
|
||||
| scopes | []string | false | A list of OAuth 2.0 scopes to use for the credentials. If not provided, default scopes are used. |
|
||||
| impersonateServiceAccount | string | false | Service account email to impersonate when making BigQuery and Dataplex API calls. The authenticated principal must have the `roles/iam.serviceAccountTokenCreator` role on the target service account. [Learn More](https://cloud.google.com/iam/docs/service-account-impersonation) |
|
||||
| maxQueryResultRows | int | false | The maximum number of rows to return from a query. Defaults to 50. |
|
||||
|
||||
@@ -771,7 +771,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version = \"0.24.0\" # x-release-please-version\n",
|
||||
"version = \"0.25.0\" # x-release-please-version\n",
|
||||
"! curl -L -o /content/toolbox https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||
"\n",
|
||||
"# Make the binary executable\n",
|
||||
|
||||
@@ -123,7 +123,7 @@ In this section, we will download and install the Toolbox binary.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
export VERSION="0.24.0"
|
||||
export VERSION="0.25.0"
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -220,7 +220,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version = \"0.24.0\" # x-release-please-version\n",
|
||||
"version = \"0.25.0\" # x-release-please-version\n",
|
||||
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||
"\n",
|
||||
"# Make the binary executable\n",
|
||||
|
||||
@@ -179,7 +179,7 @@ to use BigQuery, and then run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ In this section, we will download Toolbox and run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.25.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "mcp-toolbox-for-databases",
|
||||
"version": "0.24.0",
|
||||
"version": "0.25.0",
|
||||
"description": "MCP Toolbox for Databases is an open-source MCP server for more than 30 different datasources.",
|
||||
"contextFileName": "MCP-TOOLBOX-EXTENSION.md"
|
||||
}
|
||||
@@ -19,6 +19,7 @@ sources:
|
||||
location: ${BIGQUERY_LOCATION:}
|
||||
useClientOAuth: ${BIGQUERY_USE_CLIENT_OAUTH:false}
|
||||
scopes: ${BIGQUERY_SCOPES:}
|
||||
maxQueryResultRows: ${BIGQUERY_MAX_QUERY_RESULT_ROWS:50}
|
||||
|
||||
tools:
|
||||
analyze_contribution:
|
||||
|
||||
@@ -68,6 +68,8 @@ type ServerConfig struct {
|
||||
UI bool
|
||||
// Specifies a list of origins permitted to access this server.
|
||||
AllowedOrigins []string
|
||||
// Specifies a list of hosts permitted to access this server
|
||||
AllowedHosts []string
|
||||
}
|
||||
|
||||
type logFormat string
|
||||
|
||||
@@ -300,6 +300,21 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
||||
}
|
||||
|
||||
func hostCheck(allowedHosts map[string]struct{}) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, hasWildcard := allowedHosts["*"]
|
||||
_, hostIsAllowed := allowedHosts[r.Host]
|
||||
if !hasWildcard && !hostIsAllowed {
|
||||
// Return 400 Bad Request or 403 Forbidden to block the attack
|
||||
http.Error(w, "Invalid Host header", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// NewServer returns a Server object based on provided Config.
|
||||
func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
instrumentation, err := util.InstrumentationFromContext(ctx)
|
||||
@@ -374,7 +389,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
|
||||
// cors
|
||||
if slices.Contains(cfg.AllowedOrigins, "*") {
|
||||
s.logger.WarnContext(ctx, "wildcard (`*`) allows all origin to access the resource and is not secure. Use it with cautious for public, non-sensitive data, or during local development. Recommended to use `--allowed-origins` flag to prevent DNS rebinding attacks")
|
||||
s.logger.WarnContext(ctx, "wildcard (`*`) allows all origin to access the resource and is not secure. Use it with cautious for public, non-sensitive data, or during local development. Recommended to use `--allowed-origins` flag")
|
||||
}
|
||||
corsOpts := cors.Options{
|
||||
AllowedOrigins: cfg.AllowedOrigins,
|
||||
@@ -385,6 +400,15 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
MaxAge: 300, // cache preflight results for 5 minutes
|
||||
}
|
||||
r.Use(cors.Handler(corsOpts))
|
||||
// validate hosts for DNS rebinding attacks
|
||||
if slices.Contains(cfg.AllowedHosts, "*") {
|
||||
s.logger.WarnContext(ctx, "wildcard (`*`) allows all hosts to access the resource and is not secure. Use it with cautious for public, non-sensitive data, or during local development. Recommended to use `--allowed-hosts` flag to prevent DNS rebinding attacks")
|
||||
}
|
||||
allowedHostsMap := make(map[string]struct{}, len(cfg.AllowedHosts))
|
||||
for _, h := range cfg.AllowedHosts {
|
||||
allowedHostsMap[h] = struct{}{}
|
||||
}
|
||||
r.Use(hostCheck(allowedHostsMap))
|
||||
|
||||
// control plane
|
||||
apiR, err := apiRouter(s)
|
||||
|
||||
@@ -43,9 +43,10 @@ func TestServe(t *testing.T) {
|
||||
|
||||
addr, port := "127.0.0.1", 5000
|
||||
cfg := server.ServerConfig{
|
||||
Version: "0.0.0",
|
||||
Address: addr,
|
||||
Port: port,
|
||||
Version: "0.0.0",
|
||||
Address: addr,
|
||||
Port: port,
|
||||
AllowedHosts: []string{"*"},
|
||||
}
|
||||
|
||||
otelShutdown, err := telemetry.SetupOTel(ctx, "0.0.0", "", false, "toolbox")
|
||||
|
||||
@@ -89,6 +89,7 @@ type Config struct {
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
ImpersonateServiceAccount string `yaml:"impersonateServiceAccount"`
|
||||
Scopes StringOrStringSlice `yaml:"scopes"`
|
||||
MaxQueryResultRows int `yaml:"maxQueryResultRows"`
|
||||
}
|
||||
|
||||
// StringOrStringSlice is a custom type that can unmarshal both a single string
|
||||
@@ -127,6 +128,10 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
r.WriteMode = WriteModeAllowed
|
||||
}
|
||||
|
||||
if r.MaxQueryResultRows == 0 {
|
||||
r.MaxQueryResultRows = 50
|
||||
}
|
||||
|
||||
if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
|
||||
// The protected mode only allows write operations to the session's temporary datasets.
|
||||
// when using client OAuth, a new session is created every
|
||||
@@ -150,7 +155,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
Client: client,
|
||||
RestService: restService,
|
||||
TokenSource: tokenSource,
|
||||
MaxQueryResultRows: 50,
|
||||
MaxQueryResultRows: r.MaxQueryResultRows,
|
||||
ClientCreator: clientCreator,
|
||||
}
|
||||
|
||||
@@ -567,7 +572,7 @@ func (s *Source) RunSQL(ctx context.Context, bqClient *bigqueryapi.Client, state
|
||||
}
|
||||
|
||||
var out []any
|
||||
for {
|
||||
for s.MaxQueryResultRows <= 0 || len(out) < s.MaxQueryResultRows {
|
||||
var val []bigqueryapi.Value
|
||||
err = it.Next(&val)
|
||||
if err == iterator.Done {
|
||||
|
||||
@@ -21,9 +21,12 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
|
||||
func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
@@ -154,6 +157,26 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with max query result rows example",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
maxQueryResultRows: 10
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
MaxQueryResultRows: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
@@ -220,6 +243,59 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitialize_MaxQueryResultRows(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
ctx = util.WithUserAgent(ctx, "test-agent")
|
||||
tracer := noop.NewTracerProvider().Tracer("")
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
cfg bigquery.Config
|
||||
want int
|
||||
}{
|
||||
{
|
||||
desc: "default value",
|
||||
cfg: bigquery.Config{
|
||||
Name: "test-default",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "test-project",
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
want: 50,
|
||||
},
|
||||
{
|
||||
desc: "configured value",
|
||||
cfg: bigquery.Config{
|
||||
Name: "test-configured",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "test-project",
|
||||
UseClientOAuth: true,
|
||||
MaxQueryResultRows: 100,
|
||||
},
|
||||
want: 100,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
src, err := tc.cfg.Initialize(ctx, tracer)
|
||||
if err != nil {
|
||||
t.Fatalf("Initialize failed: %v", err)
|
||||
}
|
||||
bqSrc, ok := src.(*bigquery.Source)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *bigquery.Source, got %T", src)
|
||||
}
|
||||
if bqSrc.MaxQueryResultRows != tc.want {
|
||||
t.Errorf("MaxQueryResultRows = %d, want %d", bqSrc.MaxQueryResultRows, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -16,8 +16,12 @@ package cloudhealthcare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -255,3 +259,299 @@ func (s *Source) IsDICOMStoreAllowed(storeID string) bool {
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return s.UseClientOAuth
|
||||
}
|
||||
|
||||
func parseResults(resp *http.Response) (any, error) {
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
var jsonMap map[string]interface{}
|
||||
if err := json.Unmarshal(respBytes, &jsonMap); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
|
||||
}
|
||||
return jsonMap, nil
|
||||
}
|
||||
|
||||
func (s *Source) getService(tokenStr string) (*healthcare.Service, error) {
|
||||
svc := s.Service()
|
||||
var err error
|
||||
// Initialize new service if using user OAuth token
|
||||
if s.UseClientAuthorization() {
|
||||
svc, err = s.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (s *Source) FHIRFetchPage(ctx context.Context, url, tokenStr string) (any, error) {
|
||||
var httpClient *http.Client
|
||||
if s.UseClientAuthorization() {
|
||||
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
|
||||
httpClient = oauth2.NewClient(ctx, ts)
|
||||
} else {
|
||||
// The source.Service() object holds a client with the default credentials.
|
||||
// However, the client is not exported, so we have to create a new one.
|
||||
var err error
|
||||
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create default http client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create http request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parseResults(resp)
|
||||
}
|
||||
|
||||
func (s *Source) FHIRPatientEverything(storeID, patientID, tokenStr string, opts []googleapi.CallOption) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", s.Project(), s.Region(), s.DatasetID(), storeID, patientID)
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.PatientEverything(name).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call patient everything for %q: %w", name, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parseResults(resp)
|
||||
}
|
||||
|
||||
func (s *Source) FHIRPatientSearch(storeID, tokenStr string, opts []googleapi.CallOption) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search patient resources: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parseResults(resp)
|
||||
}
|
||||
|
||||
func (s *Source) GetDataset(tokenStr string) (*healthcare.Dataset, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
|
||||
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
return dataset, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetFHIRResource(storeID, resType, resID, tokenStr string) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", s.Project(), s.Region(), s.DatasetID(), storeID, resType, resID)
|
||||
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
|
||||
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
|
||||
resp, err := call.Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get fhir resource %q: %w", name, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parseResults(resp)
|
||||
}
|
||||
|
||||
func (s *Source) GetDICOMStore(storeID, tokenStr string) (*healthcare.DicomStore, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetFHIRStore(storeID, tokenStr string) (*healthcare.FhirStore, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetDICOMStoreMetrics(storeID, tokenStr string) (*healthcare.DicomStoreMetrics, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetFHIRStoreMetrics(storeID, tokenStr string) (*healthcare.FhirStoreMetrics, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *Source) ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.DicomStore
|
||||
for _, store := range stores.DicomStores {
|
||||
if len(s.AllowedDICOMStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
if len(store.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := s.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (s *Source) ListFHIRStores(tokenStr string) ([]*healthcare.FhirStore, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.FhirStore
|
||||
for _, store := range stores.FhirStores {
|
||||
if len(s.AllowedFHIRStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
if len(store.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := s.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (s *Source) RetrieveRenderedDICOMInstance(storeID, study, series, sop string, frame int, tokenStr string) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
|
||||
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
|
||||
call.Header().Set("Accept", "image/jpeg")
|
||||
resp, err := call.Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve dicom instance rendered image: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("RetrieveRendered: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
base64String := base64.StdEncoding.EncodeToString(respBytes)
|
||||
return base64String, nil
|
||||
}
|
||||
|
||||
func (s *Source) SearchDICOM(toolKind, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
var resp *http.Response
|
||||
switch toolKind {
|
||||
case "cloud-healthcare-search-dicom-instances":
|
||||
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
|
||||
case "cloud-healthcare-search-dicom-series":
|
||||
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
|
||||
case "cloud-healthcare-search-dicom-studies":
|
||||
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, dicomWebPath).Do(opts...)
|
||||
default:
|
||||
return nil, fmt.Errorf("incompatible tool kind: %s", toolKind)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom series: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
if len(respBytes) == 0 {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
var result []interface{}
|
||||
if err := json.Unmarshal(respBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
"fmt"
|
||||
|
||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
@@ -121,3 +123,101 @@ func initDataplexConnection(
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *Source) LookupEntry(ctx context.Context, name string, view int, aspectTypes []string, entry string) (*dataplexpb.Entry, error) {
|
||||
viewMap := map[int]dataplexpb.EntryView{
|
||||
1: dataplexpb.EntryView_BASIC,
|
||||
2: dataplexpb.EntryView_FULL,
|
||||
3: dataplexpb.EntryView_CUSTOM,
|
||||
4: dataplexpb.EntryView_ALL,
|
||||
}
|
||||
req := &dataplexpb.LookupEntryRequest{
|
||||
Name: name,
|
||||
View: viewMap[view],
|
||||
AspectTypes: aspectTypes,
|
||||
Entry: entry,
|
||||
}
|
||||
result, err := s.CatalogClient().LookupEntry(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Source) searchRequest(ctx context.Context, query string, pageSize int, orderBy string) (*dataplexapi.SearchEntriesResultIterator, error) {
|
||||
// Create SearchEntriesRequest with the provided parameters
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: query,
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", s.ProjectID()),
|
||||
PageSize: int32(pageSize),
|
||||
OrderBy: orderBy,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
// Perform the search using the CatalogClient - this will return an iterator
|
||||
it := s.CatalogClient().SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", s.ProjectID())
|
||||
}
|
||||
return it, nil
|
||||
}
|
||||
|
||||
func (s *Source) SearchAspectTypes(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.AspectType, error) {
|
||||
q := query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype"
|
||||
it, err := s.searchRequest(ctx, q, pageSize, orderBy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Iterate through the search results and call GetAspectType for each result using the resource name
|
||||
var results []*dataplexpb.AspectType
|
||||
for {
|
||||
entry, err := it.Next()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
|
||||
// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
|
||||
getAspectBackOff := backoff.NewExponentialBackOff()
|
||||
|
||||
resourceName := entry.DataplexEntry.GetEntrySource().Resource
|
||||
getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
|
||||
Name: resourceName,
|
||||
}
|
||||
|
||||
operation := func() (*dataplexpb.AspectType, error) {
|
||||
aspectType, err := s.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
|
||||
}
|
||||
return aspectType, nil
|
||||
}
|
||||
|
||||
// Retry the GetAspectType operation with exponential backoff
|
||||
aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
|
||||
}
|
||||
|
||||
results = append(results, aspectType)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Source) SearchEntries(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.SearchEntriesResult, error) {
|
||||
it, err := s.searchRequest(ctx, query, pageSize, orderBy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var results []*dataplexpb.SearchEntriesResult
|
||||
for {
|
||||
entry, err := it.Next()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
results = append(results, entry)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
@@ -16,7 +16,10 @@ package firestore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/firestore"
|
||||
"github.com/goccy/go-yaml"
|
||||
@@ -25,6 +28,7 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/api/firebaserules/v1"
|
||||
"google.golang.org/api/option"
|
||||
"google.golang.org/genproto/googleapis/type/latlng"
|
||||
)
|
||||
|
||||
const SourceKind string = "firestore"
|
||||
@@ -113,6 +117,476 @@ func (s *Source) GetDatabaseId() string {
|
||||
return s.Database
|
||||
}
|
||||
|
||||
// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
|
||||
// This removes type information and returns plain values
|
||||
func FirestoreValueToJSON(value any) any {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case time.Time:
|
||||
return v.Format(time.RFC3339Nano)
|
||||
case *latlng.LatLng:
|
||||
return map[string]any{
|
||||
"latitude": v.Latitude,
|
||||
"longitude": v.Longitude,
|
||||
}
|
||||
case []byte:
|
||||
return base64.StdEncoding.EncodeToString(v)
|
||||
case []any:
|
||||
result := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
result[i] = FirestoreValueToJSON(item)
|
||||
}
|
||||
return result
|
||||
case map[string]any:
|
||||
result := make(map[string]any)
|
||||
for k, val := range v {
|
||||
result[k] = FirestoreValueToJSON(val)
|
||||
}
|
||||
return result
|
||||
case *firestore.DocumentRef:
|
||||
return v.Path
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// BuildQuery constructs the Firestore query from parameters
|
||||
func (s *Source) BuildQuery(collectionPath string, filter firestore.EntityFilter, selectFields []string, field string, direction firestore.Direction, limit int, analyzeQuery bool) (*firestore.Query, error) {
|
||||
collection := s.FirestoreClient().Collection(collectionPath)
|
||||
query := collection.Query
|
||||
|
||||
// Process and apply filters if template is provided
|
||||
if filter != nil {
|
||||
query = query.WhereEntity(filter)
|
||||
}
|
||||
if len(selectFields) > 0 {
|
||||
query = query.Select(selectFields...)
|
||||
}
|
||||
if field != "" {
|
||||
query = query.OrderBy(field, direction)
|
||||
}
|
||||
query = query.Limit(limit)
|
||||
|
||||
// Apply analyze options if enabled
|
||||
if analyzeQuery {
|
||||
query = query.WithRunOptions(firestore.ExplainOptions{
|
||||
Analyze: true,
|
||||
})
|
||||
}
|
||||
|
||||
return &query, nil
|
||||
}
|
||||
|
||||
// QueryResult represents a document result from the query
|
||||
type QueryResult struct {
|
||||
ID string `json:"id"`
|
||||
Path string `json:"path"`
|
||||
Data map[string]any `json:"data"`
|
||||
CreateTime any `json:"createTime,omitempty"`
|
||||
UpdateTime any `json:"updateTime,omitempty"`
|
||||
ReadTime any `json:"readTime,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResponse represents the full response including optional metrics
|
||||
type QueryResponse struct {
|
||||
Documents []QueryResult `json:"documents"`
|
||||
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
|
||||
}
|
||||
|
||||
// ExecuteQuery runs the query and formats the results.
//
// On success it returns []QueryResult. When analyzeQuery is true and
// explain metrics can be extracted from the iterator, it instead returns
// a QueryResponse wrapping the documents together with those metrics;
// metric extraction is best-effort, so failures fall back to the plain
// document slice rather than erroring.
func (s *Source) ExecuteQuery(ctx context.Context, query *firestore.Query, analyzeQuery bool) (any, error) {
	docIterator := query.Documents(ctx)
	docs, err := docIterator.GetAll()
	if err != nil {
		return nil, fmt.Errorf("failed to execute query: %w", err)
	}
	// Convert results to structured format
	results := make([]QueryResult, len(docs))
	for i, doc := range docs {
		results[i] = QueryResult{
			ID:         doc.Ref.ID,
			Path:       doc.Ref.Path,
			Data:       doc.Data(),
			CreateTime: doc.CreateTime,
			UpdateTime: doc.UpdateTime,
			ReadTime:   doc.ReadTime,
		}
	}

	// Return with explain metrics if requested.
	// NOTE: metrics are read after GetAll has drained the iterator; errors
	// here are deliberately swallowed (best-effort) and plain results are
	// returned instead.
	if analyzeQuery {
		explainMetrics, err := getExplainMetrics(docIterator)
		if err == nil && explainMetrics != nil {
			response := QueryResponse{
				Documents:      results,
				ExplainMetrics: explainMetrics,
			}
			return response, nil
		}
	}
	return results, nil
}
|
||||
|
||||
// getExplainMetrics extracts explain metrics from the query iterator
|
||||
func getExplainMetrics(docIterator *firestore.DocumentIterator) (map[string]any, error) {
|
||||
explainMetrics, err := docIterator.ExplainMetrics()
|
||||
if err != nil || explainMetrics == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metricsData := make(map[string]any)
|
||||
|
||||
// Add plan summary if available
|
||||
if explainMetrics.PlanSummary != nil {
|
||||
planSummary := make(map[string]any)
|
||||
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
|
||||
metricsData["planSummary"] = planSummary
|
||||
}
|
||||
|
||||
// Add execution stats if available
|
||||
if explainMetrics.ExecutionStats != nil {
|
||||
executionStats := make(map[string]any)
|
||||
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
|
||||
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
|
||||
|
||||
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
|
||||
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
|
||||
}
|
||||
|
||||
if explainMetrics.ExecutionStats.DebugStats != nil {
|
||||
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
|
||||
}
|
||||
|
||||
metricsData["executionStats"] = executionStats
|
||||
}
|
||||
|
||||
return metricsData, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
|
||||
// Create document references from paths
|
||||
docRefs := make([]*firestore.DocumentRef, len(documentPaths))
|
||||
for i, path := range documentPaths {
|
||||
docRefs[i] = s.FirestoreClient().Doc(path)
|
||||
}
|
||||
|
||||
// Get all documents
|
||||
snapshots, err := s.FirestoreClient().GetAll(ctx, docRefs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get documents: %w", err)
|
||||
}
|
||||
|
||||
// Convert snapshots to response data
|
||||
results := make([]any, len(snapshots))
|
||||
for i, snapshot := range snapshots {
|
||||
docData := make(map[string]any)
|
||||
docData["path"] = documentPaths[i]
|
||||
docData["exists"] = snapshot.Exists()
|
||||
|
||||
if snapshot.Exists() {
|
||||
docData["data"] = snapshot.Data()
|
||||
docData["createTime"] = snapshot.CreateTime
|
||||
docData["updateTime"] = snapshot.UpdateTime
|
||||
docData["readTime"] = snapshot.ReadTime
|
||||
}
|
||||
|
||||
results[i] = docData
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Source) AddDocuments(ctx context.Context, collectionPath string, documentData any, returnData bool) (map[string]any, error) {
|
||||
// Get the collection reference
|
||||
collection := s.FirestoreClient().Collection(collectionPath)
|
||||
|
||||
// Add the document to the collection
|
||||
docRef, writeResult, err := collection.Add(ctx, documentData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add document: %w", err)
|
||||
}
|
||||
// Build the response
|
||||
response := map[string]any{
|
||||
"documentPath": docRef.Path,
|
||||
"createTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
|
||||
}
|
||||
// Add document data if requested
|
||||
if returnData {
|
||||
// Fetch the updated document to return the current state
|
||||
snapshot, err := docRef.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
|
||||
}
|
||||
// Convert the document data back to simple JSON format
|
||||
simplifiedData := FirestoreValueToJSON(snapshot.Data())
|
||||
response["documentData"] = simplifiedData
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *Source) UpdateDocument(ctx context.Context, documentPath string, updates []firestore.Update, documentData any, returnData bool) (map[string]any, error) {
|
||||
// Get the document reference
|
||||
docRef := s.FirestoreClient().Doc(documentPath)
|
||||
|
||||
// Prepare update data
|
||||
var writeResult *firestore.WriteResult
|
||||
var writeErr error
|
||||
|
||||
if len(updates) > 0 {
|
||||
writeResult, writeErr = docRef.Update(ctx, updates)
|
||||
} else {
|
||||
writeResult, writeErr = docRef.Set(ctx, documentData, firestore.MergeAll)
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
return nil, fmt.Errorf("failed to update document: %w", writeErr)
|
||||
}
|
||||
|
||||
// Build the response
|
||||
response := map[string]any{
|
||||
"documentPath": docRef.Path,
|
||||
"updateTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
|
||||
}
|
||||
|
||||
// Add document data if requested
|
||||
if returnData {
|
||||
// Fetch the updated document to return the current state
|
||||
snapshot, err := docRef.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
|
||||
}
|
||||
// Convert the document data to simple JSON format
|
||||
simplifiedData := FirestoreValueToJSON(snapshot.Data())
|
||||
response["documentData"] = simplifiedData
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *Source) DeleteDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
|
||||
// Create a BulkWriter to handle multiple deletions efficiently
|
||||
bulkWriter := s.FirestoreClient().BulkWriter(ctx)
|
||||
|
||||
// Keep track of jobs for each document
|
||||
jobs := make([]*firestore.BulkWriterJob, len(documentPaths))
|
||||
|
||||
// Add all delete operations to the BulkWriter
|
||||
for i, path := range documentPaths {
|
||||
docRef := s.FirestoreClient().Doc(path)
|
||||
job, err := bulkWriter.Delete(docRef)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
|
||||
}
|
||||
jobs[i] = job
|
||||
}
|
||||
|
||||
// End the BulkWriter to execute all operations
|
||||
bulkWriter.End()
|
||||
|
||||
// Collect results
|
||||
results := make([]any, len(documentPaths))
|
||||
for i, job := range jobs {
|
||||
docData := make(map[string]any)
|
||||
docData["path"] = documentPaths[i]
|
||||
|
||||
// Wait for the job to complete and get the result
|
||||
_, err := job.Results()
|
||||
if err != nil {
|
||||
docData["success"] = false
|
||||
docData["error"] = err.Error()
|
||||
} else {
|
||||
docData["success"] = true
|
||||
}
|
||||
|
||||
results[i] = docData
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Source) ListCollections(ctx context.Context, parentPath string) ([]any, error) {
|
||||
var collectionRefs []*firestore.CollectionRef
|
||||
var err error
|
||||
if parentPath != "" {
|
||||
// List subcollections of the specified document
|
||||
docRef := s.FirestoreClient().Doc(parentPath)
|
||||
collectionRefs, err = docRef.Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
|
||||
}
|
||||
} else {
|
||||
// List root collections
|
||||
collectionRefs, err = s.FirestoreClient().Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list root collections: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert collection references to response data
|
||||
results := make([]any, len(collectionRefs))
|
||||
for i, collRef := range collectionRefs {
|
||||
collData := make(map[string]any)
|
||||
collData["id"] = collRef.ID
|
||||
collData["path"] = collRef.Path
|
||||
|
||||
// If this is a subcollection, include parent information
|
||||
if collRef.Parent != nil {
|
||||
collData["parent"] = collRef.Parent.Path
|
||||
}
|
||||
results[i] = collData
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetRules(ctx context.Context) (any, error) {
|
||||
// Get the latest release for Firestore
|
||||
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", s.GetProjectId(), s.GetDatabaseId())
|
||||
release, err := s.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
|
||||
}
|
||||
|
||||
if release.RulesetName == "" {
|
||||
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", s.GetProjectId(), s.GetDatabaseId())
|
||||
}
|
||||
|
||||
// Get the ruleset content
|
||||
ruleset, err := s.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
|
||||
}
|
||||
|
||||
if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
|
||||
return nil, fmt.Errorf("no rules files found in ruleset")
|
||||
}
|
||||
|
||||
return ruleset, nil
|
||||
}
|
||||
|
||||
// SourcePosition represents the location of an issue in the rules source.
type SourcePosition struct {
	FileName      string `json:"fileName,omitempty"` // name of the rules file containing the issue
	Line          int64  `json:"line"`               // 1-based
	Column        int64  `json:"column"`             // 1-based
	CurrentOffset int64  `json:"currentOffset"`      // 0-based, inclusive start
	EndOffset     int64  `json:"endOffset"`          // 0-based, exclusive end
}
|
||||
|
||||
// Issue represents a single validation finding in the rules, with its
// source location, human-readable description, and severity as reported
// by the Firebase Rules API.
type Issue struct {
	SourcePosition SourcePosition `json:"sourcePosition"` // where the issue occurs
	Description    string         `json:"description"`    // human-readable explanation
	Severity       string         `json:"severity"`       // severity string from the API
}
|
||||
|
||||
// ValidationResult represents the outcome of rules validation.
type ValidationResult struct {
	Valid           bool    `json:"valid"`                     // true when no issues were reported
	IssueCount      int     `json:"issueCount"`                // number of issues found
	FormattedIssues string  `json:"formattedIssues,omitempty"` // human-readable issue report
	RawIssues       []Issue `json:"rawIssues,omitempty"`       // structured issues for programmatic use
}
|
||||
|
||||
// ValidateRules checks the given Firestore rules source for errors by
// submitting it to the Firebase Rules test API with an empty test suite.
// It returns a ValidationResult: Valid=true with a success message when
// no issues are reported, otherwise Valid=false with both a formatted,
// human-readable report (including source excerpts and caret markers)
// and the raw structured issues.
func (s *Source) ValidateRules(ctx context.Context, sourceParam string) (any, error) {
	// Create test request
	testRequest := &firebaserules.TestRulesetRequest{
		Source: &firebaserules.Source{
			Files: []*firebaserules.File{
				{
					Name:    "firestore.rules",
					Content: sourceParam,
				},
			},
		},
		// We don't need test cases for validation only
		TestSuite: &firebaserules.TestSuite{
			TestCases: []*firebaserules.TestCase{},
		},
	}
	// Call the test API
	projectName := fmt.Sprintf("projects/%s", s.GetProjectId())
	response, err := s.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to validate rules: %w", err)
	}

	// Process the response: no issues means the rules are valid.
	if len(response.Issues) == 0 {
		return ValidationResult{
			Valid:           true,
			IssueCount:      0,
			FormattedIssues: "✓ No errors detected. Rules are valid.",
		}, nil
	}

	// Convert API issues to our local Issue format.
	issues := make([]Issue, len(response.Issues))
	for i, issue := range response.Issues {
		issues[i] = Issue{
			Description: issue.Description,
			Severity:    issue.Severity,
			SourcePosition: SourcePosition{
				FileName:      issue.SourcePosition.FileName,
				Line:          issue.SourcePosition.Line,
				Column:        issue.SourcePosition.Column,
				CurrentOffset: issue.SourcePosition.CurrentOffset,
				EndOffset:     issue.SourcePosition.EndOffset,
			},
		}
	}

	// Format issues into a human-readable report, quoting the offending
	// source line with caret markers under the error token where possible.
	sourceLines := strings.Split(sourceParam, "\n")
	var formattedOutput []string

	formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))

	for _, issue := range issues {
		issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
			issue.Severity,
			issue.Description,
			issue.SourcePosition.Line,
			issue.SourcePosition.Column)

		if issue.SourcePosition.Line > 0 {
			lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
			if lineIndex >= 0 && lineIndex < len(sourceLines) {
				errorLine := sourceLines[lineIndex]
				issueString += fmt.Sprintf("\n```\n%s", errorLine)

				// Add carets if we have column and offset information
				if issue.SourcePosition.Column > 0 &&
					issue.SourcePosition.CurrentOffset >= 0 &&
					issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {

					startColumn := int(issue.SourcePosition.Column - 1) // 0-based
					errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)

					if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
						padding := strings.Repeat(" ", startColumn)
						carets := strings.Repeat("^", errorTokenLength)
						issueString += fmt.Sprintf("\n%s%s", padding, carets)
					}
				}
				issueString += "\n```"
			}
		}

		formattedOutput = append(formattedOutput, issueString)
	}

	formattedIssues := strings.Join(formattedOutput, "\n\n")

	return ValidationResult{
		Valid:           false,
		IssueCount:      len(issues),
		FormattedIssues: formattedIssues,
		RawIssues:       issues,
	}, nil
}
|
||||
|
||||
func initFirestoreConnection(
|
||||
ctx context.Context,
|
||||
tracer trace.Tracer,
|
||||
|
||||
@@ -16,6 +16,7 @@ package firestore_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -128,3 +129,37 @@ func TestFailParseFromYamlFirestore(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
|
||||
// Test round-trip conversion
|
||||
original := map[string]any{
|
||||
"name": "Test",
|
||||
"count": int64(42),
|
||||
"price": 19.99,
|
||||
"active": true,
|
||||
"tags": []any{"tag1", "tag2"},
|
||||
"metadata": map[string]any{
|
||||
"created": time.Now(),
|
||||
},
|
||||
"nullField": nil,
|
||||
}
|
||||
|
||||
// Convert to JSON representation
|
||||
jsonRepresentation := firestore.FirestoreValueToJSON(original)
|
||||
|
||||
// Verify types are simplified
|
||||
jsonMap, ok := jsonRepresentation.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map, got %T", jsonRepresentation)
|
||||
}
|
||||
|
||||
// Time should be converted to string
|
||||
metadata, ok := jsonMap["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
|
||||
}
|
||||
_, ok = metadata["created"].(string)
|
||||
if !ok {
|
||||
t.Errorf("created should be a string, got %T", metadata["created"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,9 @@ package http
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
@@ -143,3 +145,28 @@ func (s *Source) HttpQueryParams() map[string]string {
|
||||
// Client returns the HTTP client configured for this source.
func (s *Source) Client() *http.Client {
	return s.client
}
|
||||
|
||||
func (s *Source) RunRequest(req *http.Request) (any, error) {
|
||||
// Make request and fetch response
|
||||
resp, err := s.Client().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making HTTP request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var body []byte
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var data any
|
||||
if err = json.Unmarshal(body, &data); err != nil {
|
||||
// if unable to unmarshal data, return result as string.
|
||||
return string(body), nil
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
@@ -16,15 +16,21 @@ package serverlessspark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
|
||||
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
longrunning "cloud.google.com/go/longrunning/autogen"
|
||||
"cloud.google.com/go/longrunning/autogen/longrunningpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/api/iterator"
|
||||
"google.golang.org/api/option"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const SourceKind string = "serverless-spark"
|
||||
@@ -121,3 +127,168 @@ func (s *Source) Close() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) CancelOperation(ctx context.Context, operation string) (any, error) {
|
||||
req := &longrunningpb.CancelOperationRequest{
|
||||
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", s.GetProject(), s.GetLocation(), operation),
|
||||
}
|
||||
client, err := s.GetOperationsClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get operations client: %w", err)
|
||||
}
|
||||
err = client.CancelOperation(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to cancel operation: %w", err)
|
||||
}
|
||||
return fmt.Sprintf("Cancelled [%s].", operation), nil
|
||||
}
|
||||
|
||||
func (s *Source) CreateBatch(ctx context.Context, batch *dataprocpb.Batch) (map[string]any, error) {
|
||||
req := &dataprocpb.CreateBatchRequest{
|
||||
Parent: fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation()),
|
||||
Batch: batch,
|
||||
}
|
||||
|
||||
client := s.GetBatchControllerClient()
|
||||
op, err := client.CreateBatch(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create batch: %w", err)
|
||||
}
|
||||
meta, err := op.Metadata()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
|
||||
}
|
||||
|
||||
projectID, location, batchID, err := ExtractBatchDetails(meta.GetBatch())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
|
||||
}
|
||||
consoleUrl := BatchConsoleURL(projectID, location, batchID)
|
||||
logsUrl := BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})
|
||||
|
||||
wrappedResult := map[string]any{
|
||||
"opMetadata": meta,
|
||||
"consoleUrl": consoleUrl,
|
||||
"logsUrl": logsUrl,
|
||||
}
|
||||
return wrappedResult, nil
|
||||
}
|
||||
|
||||
// ListBatchesResponse is the response from the list batches API.
type ListBatchesResponse struct {
	Batches       []Batch `json:"batches"`       // one entry per batch on this page
	NextPageToken string  `json:"nextPageToken"` // token to pass back for the next page
}
|
||||
|
||||
// Batch represents a single batch job in list/get responses.
type Batch struct {
	Name       string `json:"name"`       // fully qualified resource name
	UUID       string `json:"uuid"`       // server-assigned batch UUID
	State      string `json:"state"`      // batch state enum name
	Creator    string `json:"creator"`    // identity that created the batch
	CreateTime string `json:"createTime"` // RFC 3339 creation timestamp
	Operation  string `json:"operation"`  // associated long-running operation name
	ConsoleURL string `json:"consoleUrl"` // Cloud Console batch summary page
	LogsURL    string `json:"logsUrl"`    // Cloud Logging viewer URL for this batch
}
|
||||
|
||||
// ListBatches returns one page of batches for this source's project and
// location, ordered by creation time (newest first).
//
// ps is an optional page size (nil leaves the request's page size at 0 —
// NOTE(review): iterator.NewPager then receives pageSize 0; confirm the
// intended default-page behavior), pt is an optional page token from a
// previous call, and filter is an optional server-side filter expression.
func (s *Source) ListBatches(ctx context.Context, ps *int, pt, filter string) (any, error) {
	client := s.GetBatchControllerClient()
	parent := fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation())
	req := &dataprocpb.ListBatchesRequest{
		Parent:  parent,
		OrderBy: "create_time desc",
	}

	// Apply the optional paging and filter parameters.
	if ps != nil {
		req.PageSize = int32(*ps)
	}
	if pt != "" {
		req.PageToken = pt
	}
	if filter != "" {
		req.Filter = filter
	}

	it := client.ListBatches(ctx, req)
	pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)

	// Fetch exactly one page of results.
	var batchPbs []*dataprocpb.Batch
	nextPageToken, err := pager.NextPage(&batchPbs)
	if err != nil {
		return nil, fmt.Errorf("failed to list batches: %w", err)
	}

	batches, err := ToBatches(batchPbs)
	if err != nil {
		return nil, err
	}

	return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
}
|
||||
|
||||
// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
|
||||
func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
|
||||
batches := make([]Batch, 0, len(batchPbs))
|
||||
for _, batchPb := range batchPbs {
|
||||
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating console url: %v", err)
|
||||
}
|
||||
logsUrl, err := BatchLogsURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating logs url: %v", err)
|
||||
}
|
||||
batch := Batch{
|
||||
Name: batchPb.Name,
|
||||
UUID: batchPb.Uuid,
|
||||
State: batchPb.State.Enum().String(),
|
||||
Creator: batchPb.Creator,
|
||||
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
|
||||
Operation: batchPb.Operation,
|
||||
ConsoleURL: consoleUrl,
|
||||
LogsURL: logsUrl,
|
||||
}
|
||||
batches = append(batches, batch)
|
||||
}
|
||||
return batches, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetBatch(ctx context.Context, name string) (map[string]any, error) {
|
||||
client := s.GetBatchControllerClient()
|
||||
req := &dataprocpb.GetBatchRequest{
|
||||
Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", s.GetProject(), s.GetLocation(), name),
|
||||
}
|
||||
|
||||
batchPb, err := client.GetBatch(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get batch: %w", err)
|
||||
}
|
||||
|
||||
jsonBytes, err := protojson.Marshal(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
|
||||
}
|
||||
|
||||
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating console url: %v", err)
|
||||
}
|
||||
logsUrl, err := BatchLogsURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating logs url: %v", err)
|
||||
}
|
||||
|
||||
wrappedResult := map[string]any{
|
||||
"consoleUrl": consoleUrl,
|
||||
"logsUrl": logsUrl,
|
||||
"batch": result,
|
||||
}
|
||||
|
||||
return wrappedResult, nil
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// Copyright 2025 Google LLC
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package common
|
||||
package serverlessspark
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -23,13 +23,13 @@ import (
|
||||
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
)
|
||||
|
||||
var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)
|
||||
|
||||
const (
|
||||
logTimeBufferBefore = 1 * time.Minute
|
||||
logTimeBufferAfter = 10 * time.Minute
|
||||
)
|
||||
|
||||
var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)
|
||||
|
||||
// Extract BatchDetails extracts the project ID, location, and batch ID from a fully qualified batch name.
|
||||
func ExtractBatchDetails(batchName string) (projectID, location, batchID string, err error) {
|
||||
matches := batchFullNameRegex.FindStringSubmatch(batchName)
|
||||
@@ -39,26 +39,6 @@ func ExtractBatchDetails(batchName string) (projectID, location, batchID string,
|
||||
return matches[1], matches[2], matches[3], nil
|
||||
}
|
||||
|
||||
// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
|
||||
func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return BatchConsoleURL(projectID, location, batchID), nil
|
||||
}
|
||||
|
||||
// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
|
||||
func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
createTime := batchPb.GetCreateTime().AsTime()
|
||||
stateTime := batchPb.GetStateTime().AsTime()
|
||||
return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
|
||||
}
|
||||
|
||||
// BatchConsoleURL builds a URL to the Google Cloud Console linking to the batch summary page.
|
||||
func BatchConsoleURL(projectID, location, batchID string) string {
|
||||
return fmt.Sprintf("https://console.cloud.google.com/dataproc/batches/%s/%s/summary?project=%s", location, batchID, projectID)
|
||||
@@ -89,3 +69,23 @@ resource.labels.batch_id="%s"`
|
||||
|
||||
return "https://console.cloud.google.com/logs/viewer?" + v.Encode()
|
||||
}
|
||||
|
||||
// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
|
||||
func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return BatchConsoleURL(projectID, location, batchID), nil
|
||||
}
|
||||
|
||||
// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
|
||||
func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
createTime := batchPb.GetCreateTime().AsTime()
|
||||
stateTime := batchPb.GetStateTime().AsTime()
|
||||
return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
// Copyright 2025 Google LLC
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -12,19 +12,20 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package common
|
||||
package serverlessspark_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
func TestExtractBatchDetails_Success(t *testing.T) {
|
||||
batchName := "projects/my-project/locations/us-central1/batches/my-batch"
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchName)
|
||||
projectID, location, batchID, err := serverlessspark.ExtractBatchDetails(batchName)
|
||||
if err != nil {
|
||||
t.Errorf("ExtractBatchDetails() error = %v, want no error", err)
|
||||
return
|
||||
@@ -45,7 +46,7 @@ func TestExtractBatchDetails_Success(t *testing.T) {
|
||||
|
||||
func TestExtractBatchDetails_Failure(t *testing.T) {
|
||||
batchName := "invalid-name"
|
||||
_, _, _, err := ExtractBatchDetails(batchName)
|
||||
_, _, _, err := serverlessspark.ExtractBatchDetails(batchName)
|
||||
wantErr := "failed to parse batch name: invalid-name"
|
||||
if err == nil || err.Error() != wantErr {
|
||||
t.Errorf("ExtractBatchDetails() error = %v, want %v", err, wantErr)
|
||||
@@ -53,7 +54,7 @@ func TestExtractBatchDetails_Failure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBatchConsoleURL(t *testing.T) {
|
||||
got := BatchConsoleURL("my-project", "us-central1", "my-batch")
|
||||
got := serverlessspark.BatchConsoleURL("my-project", "us-central1", "my-batch")
|
||||
want := "https://console.cloud.google.com/dataproc/batches/us-central1/my-batch/summary?project=my-project"
|
||||
if got != want {
|
||||
t.Errorf("BatchConsoleURL() = %v, want %v", got, want)
|
||||
@@ -63,7 +64,7 @@ func TestBatchConsoleURL(t *testing.T) {
|
||||
func TestBatchLogsURL(t *testing.T) {
|
||||
startTime := time.Date(2025, 10, 1, 5, 0, 0, 0, time.UTC)
|
||||
endTime := time.Date(2025, 10, 1, 6, 0, 0, 0, time.UTC)
|
||||
got := BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
|
||||
got := serverlessspark.BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
|
||||
want := "https://console.cloud.google.com/logs/viewer?advancedFilter=" +
|
||||
"resource.type%3D%22cloud_dataproc_batch%22" +
|
||||
"%0Aresource.labels.project_id%3D%22my-project%22" +
|
||||
@@ -82,7 +83,7 @@ func TestBatchConsoleURLFromProto(t *testing.T) {
|
||||
batchPb := &dataprocpb.Batch{
|
||||
Name: "projects/my-project/locations/us-central1/batches/my-batch",
|
||||
}
|
||||
got, err := BatchConsoleURLFromProto(batchPb)
|
||||
got, err := serverlessspark.BatchConsoleURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
t.Fatalf("BatchConsoleURLFromProto() error = %v", err)
|
||||
}
|
||||
@@ -100,7 +101,7 @@ func TestBatchLogsURLFromProto(t *testing.T) {
|
||||
CreateTime: timestamppb.New(createTime),
|
||||
StateTime: timestamppb.New(stateTime),
|
||||
}
|
||||
got, err := BatchLogsURLFromProto(batchPb)
|
||||
got, err := serverlessspark.BatchLogsURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
t.Fatalf("BatchLogsURLFromProto() error = %v", err)
|
||||
}
|
||||
@@ -77,26 +77,21 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
placeholderParts = append(placeholderParts, fmt.Sprintf("$%d", i+3)) // $1, $2 reserved
|
||||
}
|
||||
|
||||
var paramNamesSQL string
|
||||
var paramValuesSQL string
|
||||
|
||||
var stmt string
|
||||
if numParams > 0 {
|
||||
paramNamesSQL = fmt.Sprintf("ARRAY[%s]", strings.Join(quotedNameParts, ", "))
|
||||
paramValuesSQL = fmt.Sprintf("ARRAY[%s]", strings.Join(placeholderParts, ", "))
|
||||
paramNamesSQL := fmt.Sprintf("ARRAY[%s]", strings.Join(quotedNameParts, ", "))
|
||||
paramValuesSQL := fmt.Sprintf("ARRAY[%s]", strings.Join(placeholderParts, ", "))
|
||||
// execute_nl_query is the AlloyDB AI function that executes the natural language query
|
||||
// The first parameter is the natural language query, which is passed as $1
|
||||
// The second parameter is the NLConfig, which is passed as a $2
|
||||
// The following params are the list of PSV values passed to the NLConfig
|
||||
// Example SQL statement being executed:
|
||||
// SELECT alloydb_ai_nl.execute_nl_query(nl_question => 'How many tickets do I have?', nl_config_id => 'cymbal_air_nl_config', param_names => ARRAY ['user_email'], param_values => ARRAY ['hailongli@google.com']);
|
||||
stmtFormat := "SELECT alloydb_ai_nl.execute_nl_query(nl_question => $1, nl_config_id => $2, param_names => %s, param_values => %s);"
|
||||
stmt = fmt.Sprintf(stmtFormat, paramNamesSQL, paramValuesSQL)
|
||||
} else {
|
||||
paramNamesSQL = "ARRAY[]::TEXT[]"
|
||||
paramValuesSQL = "ARRAY[]::TEXT[]"
|
||||
stmt = "SELECT alloydb_ai_nl.execute_nl_query(nl_question => $1, nl_config_id => $2);"
|
||||
}
|
||||
|
||||
// execute_nl_query is the AlloyDB AI function that executes the natural language query
|
||||
// The first parameter is the natural language query, which is passed as $1
|
||||
// The second parameter is the NLConfig, which is passed as a $2
|
||||
// The following params are the list of PSV values passed to the NLConfig
|
||||
// Example SQL statement being executed:
|
||||
// SELECT alloydb_ai_nl.execute_nl_query(nl_question => 'How many tickets do I have?', nl_config_id => 'cymbal_air_nl_config', param_names => ARRAY ['user_email'], param_values => ARRAY ['hailongli@google.com']);
|
||||
stmtFormat := "SELECT alloydb_ai_nl.execute_nl_query(nl_question => $1, nl_config_id => $2, param_names => %s, param_values => %s);"
|
||||
stmt := fmt.Sprintf(stmtFormat, paramNamesSQL, paramValuesSQL)
|
||||
|
||||
newQuestionParam := parameters.NewStringParameter(
|
||||
"question", // name
|
||||
"The natural language question to ask.", // description
|
||||
|
||||
@@ -16,22 +16,13 @@ package fhirfetchpage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-fhir-fetch-page"
|
||||
@@ -54,13 +45,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
FHIRFetchPage(context.Context, string, string) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -118,48 +104,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
|
||||
}
|
||||
|
||||
var httpClient *http.Client
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
|
||||
httpClient = oauth2.NewClient(ctx, ts)
|
||||
} else {
|
||||
// The source.Service() object holds a client with the default credentials.
|
||||
// However, the client is not exported, so we have to create a new one.
|
||||
var err error
|
||||
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create default http client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create http request: %w", err)
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("read: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
var jsonMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
|
||||
}
|
||||
return jsonMap, nil
|
||||
return source.FHIRFetchPage(ctx, url, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,20 +16,16 @@ package fhirpatienteverything
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-fhir-patient-everything"
|
||||
@@ -54,13 +50,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
FHIRPatientEverything(string, string, string, []googleapi.CallOption) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -139,20 +131,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID)
|
||||
var opts []googleapi.CallOption
|
||||
if val, ok := params.AsMap()[typeFilterKey]; ok {
|
||||
types, ok := val.([]any)
|
||||
@@ -176,25 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
opts = append(opts, googleapi.QueryParameter("_since", sinceStr))
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.PatientEverything(name).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call patient everything for %q: %w", name, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("patient-everything: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
var jsonMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
|
||||
}
|
||||
return jsonMap, nil
|
||||
return source.FHIRPatientEverything(storeID, patientID, tokenStr, opts)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,20 +16,16 @@ package fhirpatientsearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-fhir-patient-search"
|
||||
@@ -70,13 +66,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
FHIRPatientSearch(string, string, []googleapi.CallOption) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -169,17 +161,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
|
||||
var summary bool
|
||||
@@ -248,26 +232,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if summary {
|
||||
opts = append(opts, googleapi.QueryParameter("_summary", "text"))
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search patient resources: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
var jsonMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
|
||||
}
|
||||
return jsonMap, nil
|
||||
return source.FHIRPatientSearch(storeID, tokenStr, opts)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
@@ -44,12 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
GetDataset(string) (*healthcare.Dataset, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -100,27 +95,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
return dataset, nil
|
||||
return source.GetDataset(tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
GetDICOMStore(string, string) (*healthcare.DicomStore, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -117,31 +112,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
return store, nil
|
||||
return source.GetDICOMStore(storeID, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
GetDICOMStoreMetrics(string, string) (*healthcare.DicomStoreMetrics, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -117,31 +112,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
return store, nil
|
||||
return source.GetDICOMStoreMetrics(storeID, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,18 +16,14 @@ package getfhirresource
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-get-fhir-resource"
|
||||
@@ -51,13 +47,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
GetFHIRResource(string, string, string, string) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -134,46 +126,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", typeKey)
|
||||
}
|
||||
|
||||
resID, ok := params.AsMap()[idKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", source.Project(), source.Region(), source.DatasetID(), storeID, resType, resID)
|
||||
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
|
||||
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
|
||||
resp, err := call.Do()
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get fhir resource %q: %w", name, err)
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("read: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
var jsonMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
|
||||
}
|
||||
return jsonMap, nil
|
||||
return source.GetFHIRResource(storeID, resType, resID, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
GetFHIRStore(string, string) (*healthcare.FhirStore, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -117,31 +112,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
return store, nil
|
||||
return source.GetFHIRStore(storeID, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
GetFHIRStoreMetrics(string, string) (*healthcare.FhirStoreMetrics, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -117,31 +112,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
return store, nil
|
||||
return source.GetFHIRStoreMetrics(storeID, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,12 +17,10 @@ package listdicomstores
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
@@ -45,13 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -102,41 +95,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
var filtered []*healthcare.DicomStore
|
||||
for _, store := range stores.DicomStores {
|
||||
if len(source.AllowedDICOMStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
if len(store.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := source.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
return source.ListDICOMStores(tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,12 +17,10 @@ package listfhirstores
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
@@ -45,13 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
ListFHIRStores(string) ([]*healthcare.FhirStore, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -102,41 +95,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
var filtered []*healthcare.FhirStore
|
||||
for _, store := range stores.FhirStores {
|
||||
if len(source.AllowedFHIRStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
if len(store.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := source.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
return source.ListFHIRStores(tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,18 +16,14 @@ package retrieverendereddicominstance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-retrieve-rendered-dicom-instance"
|
||||
@@ -53,13 +49,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
RetrieveRenderedDICOMInstance(string, string, string, string, int, string) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -135,20 +127,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
|
||||
study, ok := params.AsMap()[studyInstanceUIDKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid '%s' parameter; expected a string", studyInstanceUIDKey)
|
||||
@@ -165,25 +147,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey)
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
|
||||
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
|
||||
call.Header().Set("Accept", "image/jpeg")
|
||||
resp, err := call.Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve dicom instance rendered image: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("RetrieveRendered: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
base64String := base64.StdEncoding.EncodeToString(respBytes)
|
||||
return base64String, nil
|
||||
return source.RetrieveRenderedDICOMInstance(storeID, study, series, sop, frame, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,20 +16,16 @@ package searchdicominstances
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-search-dicom-instances"
|
||||
@@ -60,13 +56,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -144,23 +136,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
|
||||
opts, err := common.ParseDICOMSearchParameters(params, []string{sopInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})
|
||||
@@ -191,29 +173,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom instances: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
if len(respBytes) == 0 {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
var result []interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,18 +16,15 @@ package searchdicomseries
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
"google.golang.org/api/googleapi"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-search-dicom-series"
|
||||
@@ -57,13 +54,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -145,18 +138,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
|
||||
opts, err := common.ParseDICOMSearchParameters(params, []string{seriesInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})
|
||||
@@ -174,29 +158,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
dicomWebPath = fmt.Sprintf("studies/%s/series", id)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom series: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
if len(respBytes) == 0 {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
var result []interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,18 +16,15 @@ package searchdicomstudies
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
"google.golang.org/api/googleapi"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-search-dicom-studies"
|
||||
@@ -55,13 +52,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -136,51 +129,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
|
||||
opts, err := common.ParseDICOMSearchParameters(params, []string{studyInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, "studies").Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom studies: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
if len(respBytes) == 0 {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
var result []interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
dicomWebPath := "studies"
|
||||
return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
CatalogClient() *dataplexapi.CatalogClient
|
||||
LookupEntry(context.Context, string, int, []string, string) (*dataplexpb.Entry, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -118,12 +117,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
viewMap := map[int]dataplexpb.EntryView{
|
||||
1: dataplexpb.EntryView_BASIC,
|
||||
2: dataplexpb.EntryView_FULL,
|
||||
3: dataplexpb.EntryView_CUSTOM,
|
||||
4: dataplexpb.EntryView_ALL,
|
||||
}
|
||||
name, _ := paramsMap["name"].(string)
|
||||
entry, _ := paramsMap["entry"].(string)
|
||||
view, _ := paramsMap["view"].(int)
|
||||
@@ -132,19 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("can't convert aspectTypes to array of strings: %s", err)
|
||||
}
|
||||
aspectTypes := aspectTypeSlice.([]string)
|
||||
|
||||
req := &dataplexpb.LookupEntryRequest{
|
||||
Name: name,
|
||||
View: viewMap[view],
|
||||
AspectTypes: aspectTypes,
|
||||
Entry: entry,
|
||||
}
|
||||
|
||||
result, err := source.CatalogClient().LookupEntry(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
return source.LookupEntry(ctx, name, view, aspectTypes, entry)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -18,9 +18,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -45,8 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
CatalogClient() *dataplexapi.CatalogClient
|
||||
ProjectID() string
|
||||
SearchAspectTypes(context.Context, string, int, string) ([]*dataplexpb.AspectType, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -101,61 +98,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Invoke the tool with the provided parameters
|
||||
paramsMap := params.AsMap()
|
||||
query, _ := paramsMap["query"].(string)
|
||||
pageSize := int32(paramsMap["pageSize"].(int))
|
||||
pageSize, _ := paramsMap["pageSize"].(int)
|
||||
orderBy, _ := paramsMap["orderBy"].(string)
|
||||
|
||||
// Create SearchEntriesRequest with the provided parameters
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype",
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
|
||||
PageSize: pageSize,
|
||||
OrderBy: orderBy,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
// Perform the search using the CatalogClient - this will return an iterator
|
||||
it := source.CatalogClient().SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
|
||||
}
|
||||
|
||||
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
|
||||
// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
|
||||
getAspectBackOff := backoff.NewExponentialBackOff()
|
||||
|
||||
// Iterate through the search results and call GetAspectType for each result using the resource name
|
||||
var results []*dataplexpb.AspectType
|
||||
for {
|
||||
entry, err := it.Next()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
resourceName := entry.DataplexEntry.GetEntrySource().Resource
|
||||
getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
|
||||
Name: resourceName,
|
||||
}
|
||||
|
||||
operation := func() (*dataplexpb.AspectType, error) {
|
||||
aspectType, err := source.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
|
||||
}
|
||||
return aspectType, nil
|
||||
}
|
||||
|
||||
// Retry the GetAspectType operation with exponential backoff
|
||||
aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
|
||||
}
|
||||
|
||||
results = append(results, aspectType)
|
||||
}
|
||||
return results, nil
|
||||
return source.SearchAspectTypes(ctx, query, pageSize, orderBy)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -18,8 +18,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -44,8 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
CatalogClient() *dataplexapi.CatalogClient
|
||||
ProjectID() string
|
||||
SearchEntries(context.Context, string, int, string) ([]*dataplexpb.SearchEntriesResult, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -100,34 +98,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
query, _ := paramsMap["query"].(string)
|
||||
pageSize := int32(paramsMap["pageSize"].(int))
|
||||
pageSize, _ := paramsMap["pageSize"].(int)
|
||||
orderBy, _ := paramsMap["orderBy"].(string)
|
||||
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: query,
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
|
||||
PageSize: pageSize,
|
||||
OrderBy: orderBy,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
it := source.CatalogClient().SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
|
||||
}
|
||||
|
||||
var results []*dataplexpb.SearchEntriesResult
|
||||
for {
|
||||
entry, err := it.Next()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
results = append(results, entry)
|
||||
}
|
||||
return results, nil
|
||||
return source.SearchEntries(ctx, query, pageSize, orderBy)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -48,6 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
AddDocuments(context.Context, string, any, bool) (map[string]any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -134,24 +135,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
|
||||
// Get collection path
|
||||
collectionPath, ok := mapParams[collectionPathKey].(string)
|
||||
if !ok || collectionPath == "" {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter", collectionPathKey)
|
||||
}
|
||||
|
||||
// Validate collection path
|
||||
if err := util.ValidateCollectionPath(collectionPath); err != nil {
|
||||
return nil, fmt.Errorf("invalid collection path: %w", err)
|
||||
}
|
||||
|
||||
// Get document data
|
||||
documentDataRaw, ok := mapParams[documentDataKey]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter", documentDataKey)
|
||||
}
|
||||
|
||||
// Convert the document data from JSON format to Firestore format
|
||||
// The client is passed to handle referenceValue types
|
||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
@@ -164,30 +161,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
|
||||
returnData = val
|
||||
}
|
||||
|
||||
// Get the collection reference
|
||||
collection := source.FirestoreClient().Collection(collectionPath)
|
||||
|
||||
// Add the document to the collection
|
||||
docRef, writeResult, err := collection.Add(ctx, documentData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add document: %w", err)
|
||||
}
|
||||
|
||||
// Build the response
|
||||
response := map[string]any{
|
||||
"documentPath": docRef.Path,
|
||||
"createTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
|
||||
}
|
||||
|
||||
// Add document data if requested
|
||||
if returnData {
|
||||
// Convert the document data back to simple JSON format
|
||||
simplifiedData := util.FirestoreValueToJSON(documentData)
|
||||
response["documentData"] = simplifiedData
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return source.AddDocuments(ctx, collectionPath, documentData, returnData)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
DeleteDocuments(context.Context, []string) ([]any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -104,7 +105,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected an array", documentPathsKey)
|
||||
}
|
||||
|
||||
if len(documentPathsRaw) == 0 {
|
||||
return nil, fmt.Errorf("'%s' parameter cannot be empty", documentPathsKey)
|
||||
}
|
||||
@@ -126,45 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a BulkWriter to handle multiple deletions efficiently
|
||||
bulkWriter := source.FirestoreClient().BulkWriter(ctx)
|
||||
|
||||
// Keep track of jobs for each document
|
||||
jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths))
|
||||
|
||||
// Add all delete operations to the BulkWriter
|
||||
for i, path := range documentPaths {
|
||||
docRef := source.FirestoreClient().Doc(path)
|
||||
job, err := bulkWriter.Delete(docRef)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
|
||||
}
|
||||
jobs[i] = job
|
||||
}
|
||||
|
||||
// End the BulkWriter to execute all operations
|
||||
bulkWriter.End()
|
||||
|
||||
// Collect results
|
||||
results := make([]any, len(documentPaths))
|
||||
for i, job := range jobs {
|
||||
docData := make(map[string]any)
|
||||
docData["path"] = documentPaths[i]
|
||||
|
||||
// Wait for the job to complete and get the result
|
||||
_, err := job.Results()
|
||||
if err != nil {
|
||||
docData["success"] = false
|
||||
docData["error"] = err.Error()
|
||||
} else {
|
||||
docData["success"] = true
|
||||
}
|
||||
|
||||
results[i] = docData
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return source.DeleteDocuments(ctx, documentPaths)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
GetDocuments(context.Context, []string) ([]any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -126,37 +127,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create document references from paths
|
||||
docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths))
|
||||
for i, path := range documentPaths {
|
||||
docRefs[i] = source.FirestoreClient().Doc(path)
|
||||
}
|
||||
|
||||
// Get all documents
|
||||
snapshots, err := source.FirestoreClient().GetAll(ctx, docRefs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get documents: %w", err)
|
||||
}
|
||||
|
||||
// Convert snapshots to response data
|
||||
results := make([]any, len(snapshots))
|
||||
for i, snapshot := range snapshots {
|
||||
docData := make(map[string]any)
|
||||
docData["path"] = documentPaths[i]
|
||||
docData["exists"] = snapshot.Exists()
|
||||
|
||||
if snapshot.Exists() {
|
||||
docData["data"] = snapshot.Data()
|
||||
docData["createTime"] = snapshot.CreateTime
|
||||
docData["updateTime"] = snapshot.UpdateTime
|
||||
docData["readTime"] = snapshot.ReadTime
|
||||
}
|
||||
|
||||
results[i] = docData
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return source.GetDocuments(ctx, documentPaths)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -44,8 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirebaseRulesClient() *firebaserules.Service
|
||||
GetProjectId() string
|
||||
GetDatabaseId() string
|
||||
GetRules(context.Context) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -98,29 +97,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the latest release for Firestore
|
||||
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", source.GetProjectId(), source.GetDatabaseId())
|
||||
release, err := source.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
|
||||
}
|
||||
|
||||
if release.RulesetName == "" {
|
||||
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", source.GetProjectId(), source.GetDatabaseId())
|
||||
}
|
||||
|
||||
// Get the ruleset content
|
||||
ruleset, err := source.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
|
||||
}
|
||||
|
||||
if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
|
||||
return nil, fmt.Errorf("no rules files found in ruleset")
|
||||
}
|
||||
|
||||
return ruleset, nil
|
||||
return source.GetRules(ctx)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
ListCollections(context.Context, string) ([]any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -102,47 +103,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
mapParams := params.AsMap()
|
||||
|
||||
var collectionRefs []*firestoreapi.CollectionRef
|
||||
|
||||
// Check if parentPath is provided
|
||||
parentPath, hasParent := mapParams[parentPathKey].(string)
|
||||
|
||||
if hasParent && parentPath != "" {
|
||||
parentPath, _ := mapParams[parentPathKey].(string)
|
||||
if parentPath != "" {
|
||||
// Validate parent document path
|
||||
if err := util.ValidateDocumentPath(parentPath); err != nil {
|
||||
return nil, fmt.Errorf("invalid parent document path: %w", err)
|
||||
}
|
||||
|
||||
// List subcollections of the specified document
|
||||
docRef := source.FirestoreClient().Doc(parentPath)
|
||||
collectionRefs, err = docRef.Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
|
||||
}
|
||||
} else {
|
||||
// List root collections
|
||||
collectionRefs, err = source.FirestoreClient().Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list root collections: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert collection references to response data
|
||||
results := make([]any, len(collectionRefs))
|
||||
for i, collRef := range collectionRefs {
|
||||
collData := make(map[string]any)
|
||||
collData["id"] = collRef.ID
|
||||
collData["path"] = collRef.Path
|
||||
|
||||
// If this is a subcollection, include parent information
|
||||
if collRef.Parent != nil {
|
||||
collData["parent"] = collRef.Parent.Path
|
||||
}
|
||||
|
||||
results[i] = collData
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return source.ListCollections(ctx, parentPath)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -36,27 +36,6 @@ const (
|
||||
defaultLimit = 100
|
||||
)
|
||||
|
||||
// Firestore operators
|
||||
var validOperators = map[string]bool{
|
||||
"<": true,
|
||||
"<=": true,
|
||||
">": true,
|
||||
">=": true,
|
||||
"==": true,
|
||||
"!=": true,
|
||||
"array-contains": true,
|
||||
"array-contains-any": true,
|
||||
"in": true,
|
||||
"not-in": true,
|
||||
}
|
||||
|
||||
// Error messages
|
||||
const (
|
||||
errFilterParseFailed = "failed to parse filters: %w"
|
||||
errQueryExecutionFailed = "failed to execute query: %w"
|
||||
errLimitParseFailed = "failed to parse limit value '%s': %w"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
@@ -74,6 +53,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
// compatibleSource defines the interface for sources that can provide a Firestore client
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
BuildQuery(string, firestoreapi.EntityFilter, []string, string, firestoreapi.Direction, int, bool) (*firestoreapi.Query, error)
|
||||
ExecuteQuery(context.Context, *firestoreapi.Query, bool) (any, error)
|
||||
}
|
||||
|
||||
// Config represents the configuration for the Firestore query tool
|
||||
@@ -139,15 +120,6 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
// SimplifiedFilter represents the simplified filter format
|
||||
type SimplifiedFilter struct {
|
||||
And []SimplifiedFilter `json:"and,omitempty"`
|
||||
Or []SimplifiedFilter `json:"or,omitempty"`
|
||||
Field string `json:"field,omitempty"`
|
||||
Op string `json:"op,omitempty"`
|
||||
Value interface{} `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
// OrderByConfig represents ordering configuration
|
||||
type OrderByConfig struct {
|
||||
Field string `json:"field"`
|
||||
@@ -162,20 +134,27 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
|
||||
return firestoreapi.Asc
|
||||
}
|
||||
|
||||
// QueryResult represents a document result from the query
|
||||
type QueryResult struct {
|
||||
ID string `json:"id"`
|
||||
Path string `json:"path"`
|
||||
Data map[string]any `json:"data"`
|
||||
CreateTime interface{} `json:"createTime,omitempty"`
|
||||
UpdateTime interface{} `json:"updateTime,omitempty"`
|
||||
ReadTime interface{} `json:"readTime,omitempty"`
|
||||
// SimplifiedFilter represents the simplified filter format
|
||||
type SimplifiedFilter struct {
|
||||
And []SimplifiedFilter `json:"and,omitempty"`
|
||||
Or []SimplifiedFilter `json:"or,omitempty"`
|
||||
Field string `json:"field,omitempty"`
|
||||
Op string `json:"op,omitempty"`
|
||||
Value interface{} `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResponse represents the full response including optional metrics
|
||||
type QueryResponse struct {
|
||||
Documents []QueryResult `json:"documents"`
|
||||
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
|
||||
// Firestore operators
|
||||
var validOperators = map[string]bool{
|
||||
"<": true,
|
||||
"<=": true,
|
||||
">": true,
|
||||
">=": true,
|
||||
"==": true,
|
||||
"!=": true,
|
||||
"array-contains": true,
|
||||
"array-contains-any": true,
|
||||
"in": true,
|
||||
"not-in": true,
|
||||
}
|
||||
|
||||
// Invoke executes the Firestore query based on the provided parameters
|
||||
@@ -184,34 +163,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
// Process collection path with template substitution
|
||||
collectionPath, err := parameters.PopulateTemplate("collectionPath", t.CollectionPath, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process collection path: %w", err)
|
||||
}
|
||||
|
||||
// Build the query
|
||||
query, err := t.buildQuery(source, collectionPath, paramsMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query and return results
|
||||
return t.executeQuery(ctx, query)
|
||||
}
|
||||
|
||||
// buildQuery constructs the Firestore query from parameters
|
||||
func (t Tool) buildQuery(source compatibleSource, collectionPath string, params map[string]any) (*firestoreapi.Query, error) {
|
||||
collection := source.FirestoreClient().Collection(collectionPath)
|
||||
query := collection.Query
|
||||
|
||||
var filter firestoreapi.EntityFilter
|
||||
// Process and apply filters if template is provided
|
||||
if t.Filters != "" {
|
||||
// Apply template substitution to filters
|
||||
filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, params)
|
||||
filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process filters template: %w", err)
|
||||
}
|
||||
@@ -219,48 +182,43 @@ func (t Tool) buildQuery(source compatibleSource, collectionPath string, params
|
||||
// Parse the simplified filter format
|
||||
var simplifiedFilter SimplifiedFilter
|
||||
if err := json.Unmarshal([]byte(filtersJSON), &simplifiedFilter); err != nil {
|
||||
return nil, fmt.Errorf(errFilterParseFailed, err)
|
||||
return nil, fmt.Errorf("failed to parse filters: %w", err)
|
||||
}
|
||||
|
||||
// Convert simplified filter to Firestore filter
|
||||
if filter := t.convertToFirestoreFilter(source, simplifiedFilter); filter != nil {
|
||||
query = query.WhereEntity(filter)
|
||||
}
|
||||
filter = t.convertToFirestoreFilter(source, simplifiedFilter)
|
||||
}
|
||||
|
||||
// Process select fields
|
||||
selectFields, err := t.processSelectFields(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(selectFields) > 0 {
|
||||
query = query.Select(selectFields...)
|
||||
}
|
||||
|
||||
// Process and apply ordering
|
||||
orderBy, err := t.getOrderBy(params)
|
||||
orderBy, err := t.getOrderBy(paramsMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if orderBy != nil {
|
||||
query = query.OrderBy(orderBy.Field, orderBy.GetDirection())
|
||||
// Process select fields
|
||||
selectFields, err := t.processSelectFields(paramsMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process and apply limit
|
||||
limit, err := t.getLimit(params)
|
||||
limit, err := t.getLimit(paramsMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query = query.Limit(limit)
|
||||
|
||||
// Apply analyze options if enabled
|
||||
if t.AnalyzeQuery {
|
||||
query = query.WithRunOptions(firestoreapi.ExplainOptions{
|
||||
Analyze: true,
|
||||
})
|
||||
// prevent panic when accessing orderBy incase it is nil
|
||||
var orderByField string
|
||||
var orderByDirection firestoreapi.Direction
|
||||
if orderBy != nil {
|
||||
orderByField = orderBy.Field
|
||||
orderByDirection = orderBy.GetDirection()
|
||||
}
|
||||
|
||||
return &query, nil
|
||||
// Build the query
|
||||
query, err := source.BuildQuery(collectionPath, filter, selectFields, orderByField, orderByDirection, limit, t.AnalyzeQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Execute the query and return results
|
||||
return source.ExecuteQuery(ctx, query, t.AnalyzeQuery)
|
||||
}
|
||||
|
||||
// convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter
|
||||
@@ -409,7 +367,7 @@ func (t Tool) getLimit(params map[string]any) (int, error) {
|
||||
if processedValue != "" {
|
||||
parsedLimit, err := strconv.Atoi(processedValue)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf(errLimitParseFailed, processedValue, err)
|
||||
return 0, fmt.Errorf("failed to parse limit value '%s': %w", processedValue, err)
|
||||
}
|
||||
limit = parsedLimit
|
||||
}
|
||||
@@ -417,78 +375,6 @@ func (t Tool) getLimit(params map[string]any) (int, error) {
|
||||
return limit, nil
|
||||
}
|
||||
|
||||
// executeQuery runs the query and formats the results
|
||||
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query) (any, error) {
|
||||
docIterator := query.Documents(ctx)
|
||||
docs, err := docIterator.GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(errQueryExecutionFailed, err)
|
||||
}
|
||||
|
||||
// Convert results to structured format
|
||||
results := make([]QueryResult, len(docs))
|
||||
for i, doc := range docs {
|
||||
results[i] = QueryResult{
|
||||
ID: doc.Ref.ID,
|
||||
Path: doc.Ref.Path,
|
||||
Data: doc.Data(),
|
||||
CreateTime: doc.CreateTime,
|
||||
UpdateTime: doc.UpdateTime,
|
||||
ReadTime: doc.ReadTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Return with explain metrics if requested
|
||||
if t.AnalyzeQuery {
|
||||
explainMetrics, err := t.getExplainMetrics(docIterator)
|
||||
if err == nil && explainMetrics != nil {
|
||||
response := QueryResponse{
|
||||
Documents: results,
|
||||
ExplainMetrics: explainMetrics,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// getExplainMetrics extracts explain metrics from the query iterator
|
||||
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
|
||||
explainMetrics, err := docIterator.ExplainMetrics()
|
||||
if err != nil || explainMetrics == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metricsData := make(map[string]any)
|
||||
|
||||
// Add plan summary if available
|
||||
if explainMetrics.PlanSummary != nil {
|
||||
planSummary := make(map[string]any)
|
||||
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
|
||||
metricsData["planSummary"] = planSummary
|
||||
}
|
||||
|
||||
// Add execution stats if available
|
||||
if explainMetrics.ExecutionStats != nil {
|
||||
executionStats := make(map[string]any)
|
||||
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
|
||||
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
|
||||
|
||||
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
|
||||
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
|
||||
}
|
||||
|
||||
if explainMetrics.ExecutionStats.DebugStats != nil {
|
||||
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
|
||||
}
|
||||
|
||||
metricsData["executionStats"] = executionStats
|
||||
}
|
||||
|
||||
return metricsData, nil
|
||||
}
|
||||
|
||||
// ParseParams parses and validates input parameters
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.Parameters, data, claims)
|
||||
|
||||
@@ -69,7 +69,6 @@ const (
|
||||
errInvalidOperator = "unsupported operator: %s. Valid operators are: %v"
|
||||
errMissingFilterValue = "no value specified for filter on field '%s'"
|
||||
errOrderByParseFailed = "failed to parse orderBy: %w"
|
||||
errQueryExecutionFailed = "failed to execute query: %w"
|
||||
errTooManyFilters = "too many filters provided: %d (maximum: %d)"
|
||||
)
|
||||
|
||||
@@ -90,6 +89,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
// compatibleSource defines the interface for sources that can provide a Firestore client
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
BuildQuery(string, firestoreapi.EntityFilter, []string, string, firestoreapi.Direction, int, bool) (*firestoreapi.Query, error)
|
||||
ExecuteQuery(context.Context, *firestoreapi.Query, bool) (any, error)
|
||||
}
|
||||
|
||||
// Config represents the configuration for the Firestore query collection tool
|
||||
@@ -228,22 +229,6 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
|
||||
return firestoreapi.Asc
|
||||
}
|
||||
|
||||
// QueryResult represents a document result from the query
|
||||
type QueryResult struct {
|
||||
ID string `json:"id"`
|
||||
Path string `json:"path"`
|
||||
Data map[string]any `json:"data"`
|
||||
CreateTime interface{} `json:"createTime,omitempty"`
|
||||
UpdateTime interface{} `json:"updateTime,omitempty"`
|
||||
ReadTime interface{} `json:"readTime,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResponse represents the full response including optional metrics
|
||||
type QueryResponse struct {
|
||||
Documents []QueryResult `json:"documents"`
|
||||
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
|
||||
}
|
||||
|
||||
// Invoke executes the Firestore query based on the provided parameters
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
@@ -257,14 +242,37 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var filter firestoreapi.EntityFilter
|
||||
// Apply filters
|
||||
if len(queryParams.Filters) > 0 {
|
||||
filterConditions := make([]firestoreapi.EntityFilter, 0, len(queryParams.Filters))
|
||||
for _, filter := range queryParams.Filters {
|
||||
filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
|
||||
Path: filter.Field,
|
||||
Operator: filter.Op,
|
||||
Value: filter.Value,
|
||||
})
|
||||
}
|
||||
|
||||
filter = firestoreapi.AndFilter{
|
||||
Filters: filterConditions,
|
||||
}
|
||||
}
|
||||
|
||||
// prevent panic incase queryParams.OrderBy is nil
|
||||
var orderByField string
|
||||
var orderByDirection firestoreapi.Direction
|
||||
if queryParams.OrderBy != nil {
|
||||
orderByField = queryParams.OrderBy.Field
|
||||
orderByDirection = queryParams.OrderBy.GetDirection()
|
||||
}
|
||||
|
||||
// Build the query
|
||||
query, err := t.buildQuery(source, queryParams)
|
||||
query, err := source.BuildQuery(queryParams.CollectionPath, filter, nil, orderByField, orderByDirection, queryParams.Limit, queryParams.AnalyzeQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query and return results
|
||||
return t.executeQuery(ctx, query, queryParams.AnalyzeQuery)
|
||||
return source.ExecuteQuery(ctx, query, queryParams.AnalyzeQuery)
|
||||
}
|
||||
|
||||
// queryParameters holds all parsed query parameters
|
||||
@@ -380,122 +388,6 @@ func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) {
|
||||
return &orderBy, nil
|
||||
}
|
||||
|
||||
// buildQuery constructs the Firestore query from parameters
|
||||
func (t Tool) buildQuery(source compatibleSource, params *queryParameters) (*firestoreapi.Query, error) {
|
||||
collection := source.FirestoreClient().Collection(params.CollectionPath)
|
||||
query := collection.Query
|
||||
|
||||
// Apply filters
|
||||
if len(params.Filters) > 0 {
|
||||
filterConditions := make([]firestoreapi.EntityFilter, 0, len(params.Filters))
|
||||
for _, filter := range params.Filters {
|
||||
filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
|
||||
Path: filter.Field,
|
||||
Operator: filter.Op,
|
||||
Value: filter.Value,
|
||||
})
|
||||
}
|
||||
|
||||
query = query.WhereEntity(firestoreapi.AndFilter{
|
||||
Filters: filterConditions,
|
||||
})
|
||||
}
|
||||
|
||||
// Apply ordering
|
||||
if params.OrderBy != nil {
|
||||
query = query.OrderBy(params.OrderBy.Field, params.OrderBy.GetDirection())
|
||||
}
|
||||
|
||||
// Apply limit
|
||||
query = query.Limit(params.Limit)
|
||||
|
||||
// Apply analyze options
|
||||
if params.AnalyzeQuery {
|
||||
query = query.WithRunOptions(firestoreapi.ExplainOptions{
|
||||
Analyze: true,
|
||||
})
|
||||
}
|
||||
|
||||
return &query, nil
|
||||
}
|
||||
|
||||
// executeQuery runs the query and formats the results
|
||||
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query, analyzeQuery bool) (any, error) {
|
||||
docIterator := query.Documents(ctx)
|
||||
docs, err := docIterator.GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(errQueryExecutionFailed, err)
|
||||
}
|
||||
|
||||
// Convert results to structured format
|
||||
results := make([]QueryResult, len(docs))
|
||||
for i, doc := range docs {
|
||||
results[i] = QueryResult{
|
||||
ID: doc.Ref.ID,
|
||||
Path: doc.Ref.Path,
|
||||
Data: doc.Data(),
|
||||
CreateTime: doc.CreateTime,
|
||||
UpdateTime: doc.UpdateTime,
|
||||
ReadTime: doc.ReadTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Return with explain metrics if requested
|
||||
if analyzeQuery {
|
||||
explainMetrics, err := t.getExplainMetrics(docIterator)
|
||||
if err == nil && explainMetrics != nil {
|
||||
response := QueryResponse{
|
||||
Documents: results,
|
||||
ExplainMetrics: explainMetrics,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Return just the documents
|
||||
resultsAny := make([]any, len(results))
|
||||
for i, r := range results {
|
||||
resultsAny[i] = r
|
||||
}
|
||||
return resultsAny, nil
|
||||
}
|
||||
|
||||
// getExplainMetrics extracts explain metrics from the query iterator
|
||||
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
|
||||
explainMetrics, err := docIterator.ExplainMetrics()
|
||||
if err != nil || explainMetrics == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metricsData := make(map[string]any)
|
||||
|
||||
// Add plan summary if available
|
||||
if explainMetrics.PlanSummary != nil {
|
||||
planSummary := make(map[string]any)
|
||||
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
|
||||
metricsData["planSummary"] = planSummary
|
||||
}
|
||||
|
||||
// Add execution stats if available
|
||||
if explainMetrics.ExecutionStats != nil {
|
||||
executionStats := make(map[string]any)
|
||||
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
|
||||
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
|
||||
|
||||
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
|
||||
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
|
||||
}
|
||||
|
||||
if explainMetrics.ExecutionStats.DebugStats != nil {
|
||||
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
|
||||
}
|
||||
|
||||
metricsData["executionStats"] = executionStats
|
||||
}
|
||||
|
||||
return metricsData, nil
|
||||
}
|
||||
|
||||
// ParseParams parses and validates input parameters
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.Parameters, data, claims)
|
||||
|
||||
@@ -50,6 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
UpdateDocument(context.Context, string, []firestoreapi.Update, any, bool) (map[string]any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -177,23 +178,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get return document data flag
|
||||
returnData := false
|
||||
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
|
||||
returnData = val
|
||||
}
|
||||
|
||||
// Get the document reference
|
||||
docRef := source.FirestoreClient().Doc(documentPath)
|
||||
|
||||
// Prepare update data
|
||||
var writeResult *firestoreapi.WriteResult
|
||||
var writeErr error
|
||||
|
||||
// Use selective field update with update mask
|
||||
updates := make([]firestoreapi.Update, 0, len(updatePaths))
|
||||
var documentData any
|
||||
if len(updatePaths) > 0 {
|
||||
// Use selective field update with update mask
|
||||
updates := make([]firestoreapi.Update, 0, len(updatePaths))
|
||||
|
||||
// Convert document data without delete markers
|
||||
dataMap, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
@@ -220,41 +208,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
writeResult, writeErr = docRef.Update(ctx, updates)
|
||||
} else {
|
||||
// Update all fields in the document data (merge)
|
||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
documentData, err = util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert document data: %w", err)
|
||||
}
|
||||
writeResult, writeErr = docRef.Set(ctx, documentData, firestoreapi.MergeAll)
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
return nil, fmt.Errorf("failed to update document: %w", writeErr)
|
||||
// Get return document data flag
|
||||
returnData := false
|
||||
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
|
||||
returnData = val
|
||||
}
|
||||
|
||||
// Build the response
|
||||
response := map[string]any{
|
||||
"documentPath": docRef.Path,
|
||||
"updateTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
|
||||
}
|
||||
|
||||
// Add document data if requested
|
||||
if returnData {
|
||||
// Fetch the updated document to return the current state
|
||||
snapshot, err := docRef.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
|
||||
}
|
||||
|
||||
// Convert the document data to simple JSON format
|
||||
simplifiedData := util.FirestoreValueToJSON(snapshot.Data())
|
||||
response["documentData"] = simplifiedData
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return source.UpdateDocument(ctx, documentPath, updates, documentData, returnData)
|
||||
}
|
||||
|
||||
// getFieldValue retrieves a value from a nested map using a dot-separated path
|
||||
|
||||
@@ -17,7 +17,6 @@ package firestorevalidaterules
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
@@ -50,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirebaseRulesClient() *firebaserules.Service
|
||||
GetProjectId() string
|
||||
ValidateRules(context.Context, string) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -107,30 +106,6 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
// Issue represents a validation issue in the rules
|
||||
type Issue struct {
|
||||
SourcePosition SourcePosition `json:"sourcePosition"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
}
|
||||
|
||||
// SourcePosition represents the location of an issue in the source
|
||||
type SourcePosition struct {
|
||||
FileName string `json:"fileName,omitempty"`
|
||||
Line int64 `json:"line"` // 1-based
|
||||
Column int64 `json:"column"` // 1-based
|
||||
CurrentOffset int64 `json:"currentOffset"` // 0-based, inclusive start
|
||||
EndOffset int64 `json:"endOffset"` // 0-based, exclusive end
|
||||
}
|
||||
|
||||
// ValidationResult represents the result of rules validation
|
||||
type ValidationResult struct {
|
||||
Valid bool `json:"valid"`
|
||||
IssueCount int `json:"issueCount"`
|
||||
FormattedIssues string `json:"formattedIssues,omitempty"`
|
||||
RawIssues []Issue `json:"rawIssues,omitempty"`
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
@@ -144,114 +119,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok || sourceParam == "" {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey)
|
||||
}
|
||||
|
||||
// Create test request
|
||||
testRequest := &firebaserules.TestRulesetRequest{
|
||||
Source: &firebaserules.Source{
|
||||
Files: []*firebaserules.File{
|
||||
{
|
||||
Name: "firestore.rules",
|
||||
Content: sourceParam,
|
||||
},
|
||||
},
|
||||
},
|
||||
// We don't need test cases for validation only
|
||||
TestSuite: &firebaserules.TestSuite{
|
||||
TestCases: []*firebaserules.TestCase{},
|
||||
},
|
||||
}
|
||||
|
||||
// Call the test API
|
||||
projectName := fmt.Sprintf("projects/%s", source.GetProjectId())
|
||||
response, err := source.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate rules: %w", err)
|
||||
}
|
||||
|
||||
// Process the response
|
||||
result := t.processValidationResponse(response, sourceParam)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t Tool) processValidationResponse(response *firebaserules.TestRulesetResponse, source string) ValidationResult {
|
||||
if len(response.Issues) == 0 {
|
||||
return ValidationResult{
|
||||
Valid: true,
|
||||
IssueCount: 0,
|
||||
FormattedIssues: "✓ No errors detected. Rules are valid.",
|
||||
}
|
||||
}
|
||||
|
||||
// Convert issues to our format
|
||||
issues := make([]Issue, len(response.Issues))
|
||||
for i, issue := range response.Issues {
|
||||
issues[i] = Issue{
|
||||
Description: issue.Description,
|
||||
Severity: issue.Severity,
|
||||
SourcePosition: SourcePosition{
|
||||
FileName: issue.SourcePosition.FileName,
|
||||
Line: issue.SourcePosition.Line,
|
||||
Column: issue.SourcePosition.Column,
|
||||
CurrentOffset: issue.SourcePosition.CurrentOffset,
|
||||
EndOffset: issue.SourcePosition.EndOffset,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Format issues
|
||||
formattedIssues := t.formatRulesetIssues(issues, source)
|
||||
|
||||
return ValidationResult{
|
||||
Valid: false,
|
||||
IssueCount: len(issues),
|
||||
FormattedIssues: formattedIssues,
|
||||
RawIssues: issues,
|
||||
}
|
||||
}
|
||||
|
||||
// formatRulesetIssues formats validation issues into a human-readable string with code snippets
|
||||
func (t Tool) formatRulesetIssues(issues []Issue, rulesSource string) string {
|
||||
sourceLines := strings.Split(rulesSource, "\n")
|
||||
var formattedOutput []string
|
||||
|
||||
formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))
|
||||
|
||||
for _, issue := range issues {
|
||||
issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
|
||||
issue.Severity,
|
||||
issue.Description,
|
||||
issue.SourcePosition.Line,
|
||||
issue.SourcePosition.Column)
|
||||
|
||||
if issue.SourcePosition.Line > 0 {
|
||||
lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
|
||||
if lineIndex >= 0 && lineIndex < len(sourceLines) {
|
||||
errorLine := sourceLines[lineIndex]
|
||||
issueString += fmt.Sprintf("\n```\n%s", errorLine)
|
||||
|
||||
// Add carets if we have column and offset information
|
||||
if issue.SourcePosition.Column > 0 &&
|
||||
issue.SourcePosition.CurrentOffset >= 0 &&
|
||||
issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {
|
||||
|
||||
startColumn := int(issue.SourcePosition.Column - 1) // 0-based
|
||||
errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)
|
||||
|
||||
if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
|
||||
padding := strings.Repeat(" ", startColumn)
|
||||
carets := strings.Repeat("^", errorTokenLength)
|
||||
issueString += fmt.Sprintf("\n%s%s", padding, carets)
|
||||
}
|
||||
}
|
||||
issueString += "\n```"
|
||||
}
|
||||
}
|
||||
|
||||
formattedOutput = append(formattedOutput, issueString)
|
||||
}
|
||||
|
||||
return strings.Join(formattedOutput, "\n\n")
|
||||
return source.ValidateRules(ctx, sourceParam)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -28,13 +28,13 @@ import (
|
||||
// JSONToFirestoreValue converts a JSON value with type information to a Firestore-compatible value
|
||||
// The input should be a map with a single key indicating the type (e.g., "stringValue", "integerValue")
|
||||
// If a client is provided, referenceValue types will be converted to *firestore.DocumentRef
|
||||
func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interface{}, error) {
|
||||
func JSONToFirestoreValue(value any, client *firestore.Client) (any, error) {
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
// Check for typed values
|
||||
if len(v) == 1 {
|
||||
for key, val := range v {
|
||||
@@ -92,7 +92,7 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
|
||||
return nil, fmt.Errorf("timestamp value must be a string")
|
||||
case "geoPointValue":
|
||||
// Convert to LatLng
|
||||
if geoMap, ok := val.(map[string]interface{}); ok {
|
||||
if geoMap, ok := val.(map[string]any); ok {
|
||||
lat, latOk := geoMap["latitude"].(float64)
|
||||
lng, lngOk := geoMap["longitude"].(float64)
|
||||
if latOk && lngOk {
|
||||
@@ -105,9 +105,9 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
|
||||
return nil, fmt.Errorf("invalid geopoint value format")
|
||||
case "arrayValue":
|
||||
// Convert array
|
||||
if arrayMap, ok := val.(map[string]interface{}); ok {
|
||||
if values, ok := arrayMap["values"].([]interface{}); ok {
|
||||
result := make([]interface{}, len(values))
|
||||
if arrayMap, ok := val.(map[string]any); ok {
|
||||
if values, ok := arrayMap["values"].([]any); ok {
|
||||
result := make([]any, len(values))
|
||||
for i, item := range values {
|
||||
converted, err := JSONToFirestoreValue(item, client)
|
||||
if err != nil {
|
||||
@@ -121,9 +121,9 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
|
||||
return nil, fmt.Errorf("invalid array value format")
|
||||
case "mapValue":
|
||||
// Convert map
|
||||
if mapMap, ok := val.(map[string]interface{}); ok {
|
||||
if fields, ok := mapMap["fields"].(map[string]interface{}); ok {
|
||||
result := make(map[string]interface{})
|
||||
if mapMap, ok := val.(map[string]any); ok {
|
||||
if fields, ok := mapMap["fields"].(map[string]any); ok {
|
||||
result := make(map[string]any)
|
||||
for k, v := range fields {
|
||||
converted, err := JSONToFirestoreValue(v, client)
|
||||
if err != nil {
|
||||
@@ -160,8 +160,8 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
|
||||
}
|
||||
|
||||
// convertPlainMap converts a plain map to Firestore format
|
||||
func convertPlainMap(m map[string]interface{}, client *firestore.Client) (map[string]interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
func convertPlainMap(m map[string]any, client *firestore.Client) (map[string]any, error) {
|
||||
result := make(map[string]any)
|
||||
for k, v := range m {
|
||||
converted, err := JSONToFirestoreValue(v, client)
|
||||
if err != nil {
|
||||
@@ -172,42 +172,6 @@ func convertPlainMap(m map[string]interface{}, client *firestore.Client) (map[st
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
|
||||
// This removes type information and returns plain values
|
||||
func FirestoreValueToJSON(value interface{}) interface{} {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case time.Time:
|
||||
return v.Format(time.RFC3339Nano)
|
||||
case *latlng.LatLng:
|
||||
return map[string]interface{}{
|
||||
"latitude": v.Latitude,
|
||||
"longitude": v.Longitude,
|
||||
}
|
||||
case []byte:
|
||||
return base64.StdEncoding.EncodeToString(v)
|
||||
case []interface{}:
|
||||
result := make([]interface{}, len(v))
|
||||
for i, item := range v {
|
||||
result[i] = FirestoreValueToJSON(item)
|
||||
}
|
||||
return result
|
||||
case map[string]interface{}:
|
||||
result := make(map[string]interface{})
|
||||
for k, val := range v {
|
||||
result[k] = FirestoreValueToJSON(val)
|
||||
}
|
||||
return result
|
||||
case *firestore.DocumentRef:
|
||||
return v.Path
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// isValidDocumentPath checks if a string is a valid Firestore document path
|
||||
// Valid paths have an even number of segments (collection/doc/collection/doc...)
|
||||
func isValidDocumentPath(path string) bool {
|
||||
|
||||
@@ -312,40 +312,6 @@ func TestJSONToFirestoreValue_IntegerFromString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
|
||||
// Test round-trip conversion
|
||||
original := map[string]interface{}{
|
||||
"name": "Test",
|
||||
"count": int64(42),
|
||||
"price": 19.99,
|
||||
"active": true,
|
||||
"tags": []interface{}{"tag1", "tag2"},
|
||||
"metadata": map[string]interface{}{
|
||||
"created": time.Now(),
|
||||
},
|
||||
"nullField": nil,
|
||||
}
|
||||
|
||||
// Convert to JSON representation
|
||||
jsonRepresentation := FirestoreValueToJSON(original)
|
||||
|
||||
// Verify types are simplified
|
||||
jsonMap, ok := jsonRepresentation.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("Expected map, got %T", jsonRepresentation)
|
||||
}
|
||||
|
||||
// Time should be converted to string
|
||||
metadata, ok := jsonMap["metadata"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
|
||||
}
|
||||
_, ok = metadata["created"].(string)
|
||||
if !ok {
|
||||
t.Errorf("created should be a string, got %T", metadata["created"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONToFirestoreValue_InvalidFormats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -16,9 +16,7 @@ package http
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
@@ -54,7 +52,7 @@ type compatibleSource interface {
|
||||
HttpDefaultHeaders() map[string]string
|
||||
HttpBaseURL() string
|
||||
HttpQueryParams() map[string]string
|
||||
Client() *http.Client
|
||||
RunRequest(*http.Request) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -259,29 +257,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
for k, v := range allHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
// Make request and fetch response
|
||||
resp, err := source.Client().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making HTTP request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var body []byte
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var data any
|
||||
if err = json.Unmarshal(body, &data); err != nil {
|
||||
// if unable to unmarshal data, return result as string.
|
||||
return string(body), nil
|
||||
}
|
||||
return data, nil
|
||||
return source.RunRequest(req)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
|
||||
dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
@@ -36,9 +35,7 @@ func unmarshalProto(data any, m proto.Message) error {
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetBatchControllerClient() *dataproc.BatchControllerClient
|
||||
GetProject() string
|
||||
GetLocation() string
|
||||
CreateBatch(context.Context, *dataprocpb.Batch) (map[string]any, error)
|
||||
}
|
||||
|
||||
// Config is a common config that can be used with any type of create batch tool. However, each tool
|
||||
|
||||
@@ -16,23 +16,19 @@ package createbatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type BatchBuilder interface {
|
||||
Parameters() parameters.Parameters
|
||||
BuildBatch(params parameters.ParamValues) (*dataprocpb.Batch, error)
|
||||
BuildBatch(parameters.ParamValues) (*dataprocpb.Batch, error)
|
||||
}
|
||||
|
||||
func NewTool(cfg Config, originalCfg tools.ToolConfig, srcs map[string]sources.Source, builder BatchBuilder) (*Tool, error) {
|
||||
@@ -74,7 +70,6 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := source.GetBatchControllerClient()
|
||||
|
||||
batch, err := t.Builder.BuildBatch(params)
|
||||
if err != nil {
|
||||
@@ -97,46 +92,7 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par
|
||||
}
|
||||
batch.RuntimeConfig.Version = version
|
||||
}
|
||||
|
||||
req := &dataprocpb.CreateBatchRequest{
|
||||
Parent: fmt.Sprintf("projects/%s/locations/%s", source.GetProject(), source.GetLocation()),
|
||||
Batch: batch,
|
||||
}
|
||||
|
||||
op, err := client.CreateBatch(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create batch: %w", err)
|
||||
}
|
||||
|
||||
meta, err := op.Metadata()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
|
||||
}
|
||||
|
||||
jsonBytes, err := protojson.Marshal(meta)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal create batch op metadata to JSON: %w", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal create batch op metadata JSON: %w", err)
|
||||
}
|
||||
|
||||
projectID, location, batchID, err := common.ExtractBatchDetails(meta.GetBatch())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
|
||||
}
|
||||
consoleUrl := common.BatchConsoleURL(projectID, location, batchID)
|
||||
logsUrl := common.BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})
|
||||
|
||||
wrappedResult := map[string]any{
|
||||
"opMetadata": meta,
|
||||
"consoleUrl": consoleUrl,
|
||||
"logsUrl": logsUrl,
|
||||
}
|
||||
|
||||
return wrappedResult, nil
|
||||
return source.CreateBatch(ctx, batch)
|
||||
}
|
||||
|
||||
func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -19,8 +19,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
longrunning "cloud.google.com/go/longrunning/autogen"
|
||||
"cloud.google.com/go/longrunning/autogen/longrunningpb"
|
||||
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -45,9 +44,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetOperationsClient(context.Context) (*longrunning.OperationsClient, error)
|
||||
GetProject() string
|
||||
GetLocation() string
|
||||
GetBatchControllerClient() *dataproc.BatchControllerClient
|
||||
CancelOperation(context.Context, string) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -106,32 +104,15 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := source.GetOperationsClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get operations client: %w", err)
|
||||
}
|
||||
|
||||
paramMap := params.AsMap()
|
||||
operation, ok := paramMap["operation"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing required parameter: operation")
|
||||
}
|
||||
|
||||
if strings.Contains(operation, "/") {
|
||||
return nil, fmt.Errorf("operation must be a short operation name without '/': %s", operation)
|
||||
}
|
||||
|
||||
req := &longrunningpb.CancelOperationRequest{
|
||||
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", source.GetProject(), source.GetLocation(), operation),
|
||||
}
|
||||
|
||||
err = client.CancelOperation(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to cancel operation: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Cancelled [%s].", operation), nil
|
||||
return source.CancelOperation(ctx, operation)
|
||||
}
|
||||
|
||||
func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,19 +16,15 @@ package serverlesssparkgetbatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
|
||||
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const kind = "serverless-spark-get-batch"
|
||||
@@ -49,8 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
GetBatchControllerClient() *dataproc.BatchControllerClient
|
||||
GetProject() string
|
||||
GetLocation() string
|
||||
GetBatch(context.Context, string) (map[string]any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -109,54 +104,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := source.GetBatchControllerClient()
|
||||
|
||||
paramMap := params.AsMap()
|
||||
name, ok := paramMap["name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing required parameter: name")
|
||||
}
|
||||
|
||||
if strings.Contains(name, "/") {
|
||||
return nil, fmt.Errorf("name must be a short batch name without '/': %s", name)
|
||||
}
|
||||
|
||||
req := &dataprocpb.GetBatchRequest{
|
||||
Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", source.GetProject(), source.GetLocation(), name),
|
||||
}
|
||||
|
||||
batchPb, err := client.GetBatch(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get batch: %w", err)
|
||||
}
|
||||
|
||||
jsonBytes, err := protojson.Marshal(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
|
||||
}
|
||||
|
||||
consoleUrl, err := common.BatchConsoleURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating console url: %v", err)
|
||||
}
|
||||
logsUrl, err := common.BatchLogsURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating logs url: %v", err)
|
||||
}
|
||||
|
||||
wrappedResult := map[string]any{
|
||||
"consoleUrl": consoleUrl,
|
||||
"logsUrl": logsUrl,
|
||||
"batch": result,
|
||||
}
|
||||
|
||||
return wrappedResult, nil
|
||||
return source.GetBatch(ctx, name)
|
||||
}
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.Parameters, data, claims)
|
||||
|
||||
@@ -17,17 +17,13 @@ package serverlesssparklistbatches
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
|
||||
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
const kind = "serverless-spark-list-batches"
|
||||
@@ -48,8 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
GetBatchControllerClient() *dataproc.BatchControllerClient
|
||||
GetProject() string
|
||||
GetLocation() string
|
||||
ListBatches(context.Context, *int, string, string) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -104,95 +99,24 @@ type Tool struct {
|
||||
Parameters parameters.Parameters
|
||||
}
|
||||
|
||||
// ListBatchesResponse is the response from the list batches API.
|
||||
type ListBatchesResponse struct {
|
||||
Batches []Batch `json:"batches"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
}
|
||||
|
||||
// Batch represents a single batch job.
|
||||
type Batch struct {
|
||||
Name string `json:"name"`
|
||||
UUID string `json:"uuid"`
|
||||
State string `json:"state"`
|
||||
Creator string `json:"creator"`
|
||||
CreateTime string `json:"createTime"`
|
||||
Operation string `json:"operation"`
|
||||
ConsoleURL string `json:"consoleUrl"`
|
||||
LogsURL string `json:"logsUrl"`
|
||||
}
|
||||
|
||||
// Invoke executes the tool's operation.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := source.GetBatchControllerClient()
|
||||
|
||||
parent := fmt.Sprintf("projects/%s/locations/%s", source.GetProject(), source.GetLocation())
|
||||
req := &dataprocpb.ListBatchesRequest{
|
||||
Parent: parent,
|
||||
OrderBy: "create_time desc",
|
||||
}
|
||||
|
||||
paramMap := params.AsMap()
|
||||
var pageSize *int
|
||||
if ps, ok := paramMap["pageSize"]; ok && ps != nil {
|
||||
req.PageSize = int32(ps.(int))
|
||||
if (req.PageSize) <= 0 {
|
||||
return nil, fmt.Errorf("pageSize must be positive: %d", req.PageSize)
|
||||
pageSizeV := ps.(int)
|
||||
if pageSizeV <= 0 {
|
||||
return nil, fmt.Errorf("pageSize must be positive: %d", pageSizeV)
|
||||
}
|
||||
pageSize = &pageSizeV
|
||||
}
|
||||
if pt, ok := paramMap["pageToken"]; ok && pt != nil {
|
||||
req.PageToken = pt.(string)
|
||||
}
|
||||
if filter, ok := paramMap["filter"]; ok && filter != nil {
|
||||
req.Filter = filter.(string)
|
||||
}
|
||||
|
||||
it := client.ListBatches(ctx, req)
|
||||
pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)
|
||||
|
||||
var batchPbs []*dataprocpb.Batch
|
||||
nextPageToken, err := pager.NextPage(&batchPbs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list batches: %w", err)
|
||||
}
|
||||
|
||||
batches, err := ToBatches(batchPbs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
|
||||
}
|
||||
|
||||
// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
|
||||
func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
|
||||
batches := make([]Batch, 0, len(batchPbs))
|
||||
for _, batchPb := range batchPbs {
|
||||
consoleUrl, err := common.BatchConsoleURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating console url: %v", err)
|
||||
}
|
||||
logsUrl, err := common.BatchLogsURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating logs url: %v", err)
|
||||
}
|
||||
batch := Batch{
|
||||
Name: batchPb.Name,
|
||||
UUID: batchPb.Uuid,
|
||||
State: batchPb.State.Enum().String(),
|
||||
Creator: batchPb.Creator,
|
||||
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
|
||||
Operation: batchPb.Operation,
|
||||
ConsoleURL: consoleUrl,
|
||||
LogsURL: logsUrl,
|
||||
}
|
||||
batches = append(batches, batch)
|
||||
}
|
||||
return batches, nil
|
||||
pt, _ := paramMap["pageToken"].(string)
|
||||
filter, _ := paramMap["filter"].(string)
|
||||
return source.ListBatches(ctx, pageSize, pt, filter)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -478,9 +478,13 @@ func (ps Parameters) McpManifest() (McpToolsSchema, map[string][]string) {
|
||||
for _, p := range ps {
|
||||
name := p.GetName()
|
||||
paramManifest, authParamList := p.McpManifest()
|
||||
defaultV := p.GetDefault()
|
||||
if defaultV != nil {
|
||||
paramManifest.Default = defaultV
|
||||
}
|
||||
properties[name] = paramManifest
|
||||
// parameters that doesn't have a default value are added to the required field
|
||||
if CheckParamRequired(p.GetRequired(), p.GetDefault()) {
|
||||
if CheckParamRequired(p.GetRequired(), defaultV) {
|
||||
required = append(required, name)
|
||||
}
|
||||
if len(authParamList) > 0 {
|
||||
@@ -502,6 +506,7 @@ type ParameterManifest struct {
|
||||
Description string `json:"description"`
|
||||
AuthServices []string `json:"authSources"`
|
||||
Items *ParameterManifest `json:"items,omitempty"`
|
||||
Default any `json:"default,omitempty"`
|
||||
AdditionalProperties any `json:"additionalProperties,omitempty"`
|
||||
EmbeddedBy string `json:"embeddedBy,omitempty"`
|
||||
}
|
||||
@@ -511,6 +516,7 @@ type ParameterMcpManifest struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Items *ParameterMcpManifest `json:"items,omitempty"`
|
||||
Default any `json:"default,omitempty"`
|
||||
AdditionalProperties any `json:"additionalProperties,omitempty"`
|
||||
}
|
||||
|
||||
@@ -788,6 +794,7 @@ func (p *StringParameter) Manifest() ParameterManifest {
|
||||
Required: r,
|
||||
Description: p.Desc,
|
||||
AuthServices: authServiceNames,
|
||||
Default: p.GetDefault(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -946,6 +953,7 @@ func (p *IntParameter) Manifest() ParameterManifest {
|
||||
Required: r,
|
||||
Description: p.Desc,
|
||||
AuthServices: authServiceNames,
|
||||
Default: p.GetDefault(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1102,6 +1110,7 @@ func (p *FloatParameter) Manifest() ParameterManifest {
|
||||
Required: r,
|
||||
Description: p.Desc,
|
||||
AuthServices: authServiceNames,
|
||||
Default: p.GetDefault(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1235,6 +1244,7 @@ func (p *BooleanParameter) Manifest() ParameterManifest {
|
||||
Required: r,
|
||||
Description: p.Desc,
|
||||
AuthServices: authServiceNames,
|
||||
Default: p.GetDefault(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1430,6 +1440,7 @@ func (p *ArrayParameter) Manifest() ParameterManifest {
|
||||
Description: p.Desc,
|
||||
AuthServices: authServiceNames,
|
||||
Items: &items,
|
||||
Default: p.GetDefault(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1675,7 +1686,10 @@ func (p *MapParameter) Manifest() ParameterManifest {
|
||||
// If no valueType is given, allow any properties.
|
||||
additionalProperties = true
|
||||
}
|
||||
|
||||
var defaultV any
|
||||
if p.Default != nil {
|
||||
defaultV = *p.Default
|
||||
}
|
||||
return ParameterManifest{
|
||||
Name: p.Name,
|
||||
Type: "object",
|
||||
@@ -1683,6 +1697,7 @@ func (p *MapParameter) Manifest() ParameterManifest {
|
||||
Description: p.Desc,
|
||||
AuthServices: authServiceNames,
|
||||
AdditionalProperties: additionalProperties,
|
||||
Default: defaultV,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1625,22 +1625,22 @@ func TestParamManifest(t *testing.T) {
|
||||
{
|
||||
name: "string default",
|
||||
in: parameters.NewStringParameterWithDefault("foo-string", "foo", "bar"),
|
||||
want: parameters.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}},
|
||||
want: parameters.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", Default: "foo", AuthServices: []string{}},
|
||||
},
|
||||
{
|
||||
name: "int default",
|
||||
in: parameters.NewIntParameterWithDefault("foo-int", 1, "bar"),
|
||||
want: parameters.ParameterManifest{Name: "foo-int", Type: "integer", Required: false, Description: "bar", AuthServices: []string{}},
|
||||
want: parameters.ParameterManifest{Name: "foo-int", Type: "integer", Required: false, Description: "bar", Default: 1, AuthServices: []string{}},
|
||||
},
|
||||
{
|
||||
name: "float default",
|
||||
in: parameters.NewFloatParameterWithDefault("foo-float", 1.1, "bar"),
|
||||
want: parameters.ParameterManifest{Name: "foo-float", Type: "float", Required: false, Description: "bar", AuthServices: []string{}},
|
||||
want: parameters.ParameterManifest{Name: "foo-float", Type: "float", Required: false, Description: "bar", Default: 1.1, AuthServices: []string{}},
|
||||
},
|
||||
{
|
||||
name: "boolean default",
|
||||
in: parameters.NewBooleanParameterWithDefault("foo-bool", true, "bar"),
|
||||
want: parameters.ParameterManifest{Name: "foo-bool", Type: "boolean", Required: false, Description: "bar", AuthServices: []string{}},
|
||||
want: parameters.ParameterManifest{Name: "foo-bool", Type: "boolean", Required: false, Description: "bar", Default: true, AuthServices: []string{}},
|
||||
},
|
||||
{
|
||||
name: "array default",
|
||||
@@ -1650,6 +1650,7 @@ func TestParamManifest(t *testing.T) {
|
||||
Type: "array",
|
||||
Required: false,
|
||||
Description: "bar",
|
||||
Default: []any{"foo", "bar"},
|
||||
AuthServices: []string{},
|
||||
Items: ¶meters.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}},
|
||||
},
|
||||
@@ -1841,7 +1842,7 @@ func TestMcpManifest(t *testing.T) {
|
||||
wantSchema: parameters.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]parameters.ParameterMcpManifest{
|
||||
"foo-string": {Type: "string", Description: "bar"},
|
||||
"foo-string": {Type: "string", Description: "bar", Default: "foo"},
|
||||
"foo-string2": {Type: "string", Description: "bar"},
|
||||
"foo-string3-auth": {Type: "string", Description: "bar"},
|
||||
"foo-int2": {Type: "integer", Description: "bar"},
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"$schema": "https://static.modelcontextprotocol.io/schemas/2025-10-17/server.schema.json",
|
||||
"$schema": "https://static.modelcontextprotocol.io/schemas/2025-12-11/server.schema.json",
|
||||
"name": "io.github.googleapis/genai-toolbox",
|
||||
"description": "MCP Toolbox for Databases enables your agent to connect to your database.",
|
||||
"title": "MCP Toolbox",
|
||||
@@ -14,11 +14,11 @@
|
||||
"url": "https://github.com/googleapis/genai-toolbox",
|
||||
"source": "github"
|
||||
},
|
||||
"version": "0.24.0",
|
||||
"version": "0.25.0",
|
||||
"packages": [
|
||||
{
|
||||
"registryType": "oci",
|
||||
"identifier": "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:0.24.0",
|
||||
"identifier": "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:0.25.0",
|
||||
"transport": {
|
||||
"type": "streamable-http",
|
||||
"url": "http://{host}:{port}/mcp"
|
||||
|
||||
@@ -194,7 +194,7 @@ func runAlloyDBToolGetTest(t *testing.T) {
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{
|
||||
map[string]any{"name": "project", "type": "string", "description": "The GCP project ID to list clusters for.", "required": true, "authSources": []any{}},
|
||||
map[string]any{"name": "location", "type": "string", "description": "Optional: The location to list clusters in (e.g., 'us-central1'). Use '-' to list clusters across all locations.(Default: '-')", "required": false, "authSources": []any{}},
|
||||
map[string]any{"name": "location", "type": "string", "description": "Optional: The location to list clusters in (e.g., 'us-central1'). Use '-' to list clusters across all locations.(Default: '-')", "required": false, "default": "-", "authSources": []any{}},
|
||||
},
|
||||
"authRequired": []any{},
|
||||
},
|
||||
|
||||
@@ -431,6 +431,7 @@ func TestLooker(t *testing.T) {
|
||||
"description": "The filters for the query",
|
||||
"name": "filters",
|
||||
"required": false,
|
||||
"default": map[string]any{},
|
||||
"type": "object",
|
||||
},
|
||||
map[string]any{
|
||||
@@ -446,6 +447,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "pivots",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
"default": []any{},
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -460,6 +462,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "sorts",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
"default": []any{},
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -467,6 +470,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "limit",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(500),
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -519,6 +523,7 @@ func TestLooker(t *testing.T) {
|
||||
"description": "The filters for the query",
|
||||
"name": "filters",
|
||||
"required": false,
|
||||
"default": map[string]any{},
|
||||
"type": "object",
|
||||
},
|
||||
map[string]any{
|
||||
@@ -534,6 +539,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "pivots",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
"default": []any{},
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -548,6 +554,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "sorts",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
"default": []any{},
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -555,6 +562,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "limit",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(500),
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -607,6 +615,7 @@ func TestLooker(t *testing.T) {
|
||||
"description": "The filters for the query",
|
||||
"name": "filters",
|
||||
"required": false,
|
||||
"default": map[string]any{},
|
||||
"type": "object",
|
||||
},
|
||||
map[string]any{
|
||||
@@ -622,6 +631,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "pivots",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
"default": []any{},
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -636,6 +646,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "sorts",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
"default": []any{},
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -643,6 +654,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "limit",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(500),
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -658,6 +670,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "vis_config",
|
||||
"required": false,
|
||||
"type": "object",
|
||||
"default": map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -675,6 +688,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "title",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -682,6 +696,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "desc",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -689,6 +704,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "limit",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(100),
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -696,6 +712,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "offset",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(0),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -741,6 +758,7 @@ func TestLooker(t *testing.T) {
|
||||
"description": "The filters for the query",
|
||||
"name": "filters",
|
||||
"required": false,
|
||||
"default": map[string]any{},
|
||||
"type": "object",
|
||||
},
|
||||
map[string]any{
|
||||
@@ -756,6 +774,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "pivots",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
"default": []any{},
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -770,6 +789,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "sorts",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
"default": []any{},
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -777,6 +797,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "limit",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(500),
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -797,6 +818,7 @@ func TestLooker(t *testing.T) {
|
||||
"description": "The description of the Look",
|
||||
"name": "description",
|
||||
"required": false,
|
||||
"default": "",
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
@@ -804,6 +826,7 @@ func TestLooker(t *testing.T) {
|
||||
"description": "The folder id where the Look will be created. Leave blank to use the user's personal folder",
|
||||
"name": "folder",
|
||||
"required": false,
|
||||
"default": "",
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
@@ -813,6 +836,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "vis_config",
|
||||
"required": false,
|
||||
"type": "object",
|
||||
"default": map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -830,6 +854,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "title",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -837,6 +862,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "desc",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -844,6 +870,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "limit",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(100),
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -851,6 +878,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "offset",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(0),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -875,6 +903,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "description",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -882,6 +911,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "folder",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -920,6 +950,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "filter_type",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "field_filter",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -955,6 +986,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "allow_multiple_values",
|
||||
"required": false,
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -962,6 +994,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "required",
|
||||
"required": false,
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1007,6 +1040,7 @@ func TestLooker(t *testing.T) {
|
||||
"description": "The filters for the query",
|
||||
"name": "filters",
|
||||
"required": false,
|
||||
"default": map[string]any{},
|
||||
"type": "object",
|
||||
},
|
||||
map[string]any{
|
||||
@@ -1022,6 +1056,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "pivots",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
"default": []any{},
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -1036,6 +1071,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "sorts",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
"default": []any{},
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -1043,6 +1079,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "limit",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(500),
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -1064,6 +1101,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "title",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
map[string]any{
|
||||
"additionalProperties": true,
|
||||
@@ -1072,6 +1110,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "vis_config",
|
||||
"required": false,
|
||||
"type": "object",
|
||||
"default": map[string]any{},
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -1083,6 +1122,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "dashboard_filter",
|
||||
"required": false,
|
||||
"type": "object",
|
||||
"default": map[string]any{},
|
||||
},
|
||||
"name": "dashboard_filters",
|
||||
"required": false,
|
||||
@@ -1181,6 +1221,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "timeframe",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(90),
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -1188,6 +1229,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "min_queries",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(0),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1212,6 +1254,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "project",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -1219,6 +1262,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "model",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -1226,6 +1270,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "explore",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -1233,6 +1278,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "timeframe",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(90),
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -1240,6 +1286,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "min_queries",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
"default": float64(1),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1257,6 +1304,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "devMode",
|
||||
"required": false,
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1410,6 +1458,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "type",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
@@ -1417,6 +1466,7 @@ func TestLooker(t *testing.T) {
|
||||
"name": "id",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -165,6 +165,7 @@ func TestNeo4jToolEndpoints(t *testing.T) {
|
||||
"type": "boolean",
|
||||
"required": false,
|
||||
"description": "If set to true, the query will be validated and information about the execution will be returned without running the query. Defaults to false.",
|
||||
"default": false,
|
||||
"authSources": []any{},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -33,8 +33,8 @@ import (
|
||||
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
|
||||
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
"google.golang.org/api/iterator"
|
||||
"google.golang.org/api/option"
|
||||
@@ -676,7 +676,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
|
||||
filter string
|
||||
pageSize int
|
||||
numPages int
|
||||
want []serverlesssparklistbatches.Batch
|
||||
want []serverlessspark.Batch
|
||||
}{
|
||||
{name: "one page", pageSize: 2, numPages: 1, want: batch2},
|
||||
{name: "two pages", pageSize: 1, numPages: 2, want: batch2},
|
||||
@@ -701,7 +701,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var actual []serverlesssparklistbatches.Batch
|
||||
var actual []serverlessspark.Batch
|
||||
var pageToken string
|
||||
for i := 0; i < tc.numPages; i++ {
|
||||
request := map[string]any{
|
||||
@@ -733,7 +733,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
var listResponse serverlesssparklistbatches.ListBatchesResponse
|
||||
var listResponse serverlessspark.ListBatchesResponse
|
||||
if err := json.Unmarshal([]byte(result), &listResponse); err != nil {
|
||||
t.Fatalf("error unmarshalling result: %s", err)
|
||||
}
|
||||
@@ -759,7 +759,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
|
||||
}
|
||||
}
|
||||
|
||||
func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, filter string, n int, exact bool) []serverlesssparklistbatches.Batch {
|
||||
func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, filter string, n int, exact bool) []serverlessspark.Batch {
|
||||
parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation)
|
||||
req := &dataprocpb.ListBatchesRequest{
|
||||
Parent: parent,
|
||||
@@ -783,7 +783,7 @@ func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx co
|
||||
if !exact && (len(batchPbs) == 0 || len(batchPbs) > n) {
|
||||
t.Fatalf("expected between 1 and %d batches, got %d", n, len(batchPbs))
|
||||
}
|
||||
batches, err := serverlesssparklistbatches.ToBatches(batchPbs)
|
||||
batches, err := serverlessspark.ToBatches(batchPbs)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to convert batches to JSON: %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user