Compare commits

...

17 Commits

Author SHA1 Message Date
Twisha Bansal
92ad186cd7 Merge branch 'main' into anti-prompt-quick 2025-12-22 15:11:15 +05:30
Yuan Teoh
967a72da11 refactor: decouple Source from Tool (#2204)
This PR update the linking mechanism between Source and Tool.

Tools are directly linked to their Source, either by pointing to the
Source's functions or by assigning values from the source during Tool's
initialization. However, the existing approach means that any
modification to the Source after Tool's initialization might not be
reflected. To address this limitation, each tool should only store a
name reference to the Source, rather than direct link or assigned
values.

Tools will provide interface for `compatibleSource`. This will be used
to determine if a Source is compatible with the Tool.
```
type compatibleSource interface{
    Client() http.Client
    ProjectID() string
}
```

During `Invoke()`, the tool will run the following operations:
* retrieve Source from the `resourceManager` with source's named defined
in Tool's config
* validate Source via `compatibleSource interface{}`
* run the remaining `Invoke()` function. Fields that are needed is
retrieved directly from the source.

With this update, resource manager is also added as input to other
Tool's function that require access to source (e.g.
`RequiresClientAuthorization()`).
2025-12-19 21:27:55 -08:00
Yuan Teoh
7daa4111f4 fix: add import for cloudgda source (#2217) 2025-12-19 15:36:17 -08:00
Averi Kitsch
18885f6433 ci: update renovate to use dep groups (#2142)
## Description

> Should include a concise description of the changes (bug or feature),
it's
> impact, along with a summary of the solution

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #<issue_number_goes_here>
2025-12-19 22:19:01 +00:00
Mend Renovate
21d676ed58 chore(deps): update module github.com/redis/go-redis/v9 to v9.17.2 (#1994)
This PR contains the following updates:

| Package | Change |
[Age](https://docs.renovatebot.com/merge-confidence/) |
[Confidence](https://docs.renovatebot.com/merge-confidence/) |
|---|---|---|---|
|
[github.com/redis/go-redis/v9](https://redirect.github.com/redis/go-redis)
| `v9.16.0` -> `v9.17.2` |
![age](https://developer.mend.io/api/mc/badges/age/go/github.com%2fredis%2fgo-redis%2fv9/v9.17.2?slim=true)
|
![confidence](https://developer.mend.io/api/mc/badges/confidence/go/github.com%2fredis%2fgo-redis%2fv9/v9.16.0/v9.17.2?slim=true)
|

---

### Release Notes

<details>
<summary>redis/go-redis (github.com/redis/go-redis/v9)</summary>

###
[`v9.17.2`](https://redirect.github.com/redis/go-redis/releases/tag/v9.17.2):
9.17.2

[Compare
Source](https://redirect.github.com/redis/go-redis/compare/v9.17.1...v9.17.2)

#### 🐛 Bug Fixes

- **Connection Pool**: Fixed critical race condition in turn management
that could cause connection leaks when dial goroutines complete after
request timeout
([#&#8203;3626](https://redirect.github.com/redis/go-redis/pull/3626))
by [@&#8203;cyningsun](https://redirect.github.com/cyningsun)
- **Context Timeout**: Improved context timeout calculation to use
minimum of remaining time and DialTimeout, preventing goroutines from
waiting longer than necessary
([#&#8203;3626](https://redirect.github.com/redis/go-redis/pull/3626))
by [@&#8203;cyningsun](https://redirect.github.com/cyningsun)

#### 🧰 Maintenance

- chore(deps): bump rojopolis/spellcheck-github-actions from 0.54.0 to
0.55.0
([#&#8203;3627](https://redirect.github.com/redis/go-redis/pull/3627))

#### Contributors

We'd like to thank all the contributors who worked on this release!

[@&#8203;cyningsun](https://redirect.github.com/cyningsun) and
[@&#8203;ndyakov](https://redirect.github.com/ndyakov)

###
[`v9.17.1`](https://redirect.github.com/redis/go-redis/releases/tag/v9.17.1):
9.17.1

[Compare
Source](https://redirect.github.com/redis/go-redis/compare/v9.17.0...v9.17.1)

#### 🐛 Bug Fixes

- add wait to keyless commands list
([#&#8203;3615](https://redirect.github.com/redis/go-redis/pull/3615))
by [@&#8203;marcoferrer](https://redirect.github.com/marcoferrer)
- fix(time): remove cached time optimization
([#&#8203;3611](https://redirect.github.com/redis/go-redis/pull/3611))
by [@&#8203;ndyakov](https://redirect.github.com/ndyakov)

#### 🧰 Maintenance

- chore(deps): bump golangci/golangci-lint-action from 9.0.0 to 9.1.0
([#&#8203;3609](https://redirect.github.com/redis/go-redis/pull/3609))
- chore(deps): bump actions/checkout from 5 to 6
([#&#8203;3610](https://redirect.github.com/redis/go-redis/pull/3610))
- chore(script): fix help call in tag.sh
([#&#8203;3606](https://redirect.github.com/redis/go-redis/pull/3606))
by [@&#8203;ndyakov](https://redirect.github.com/ndyakov)

#### Contributors

We'd like to thank all the contributors who worked on this release!

[@&#8203;marcoferrer](https://redirect.github.com/marcoferrer) and
[@&#8203;ndyakov](https://redirect.github.com/ndyakov)

###
[`v9.17.0`](https://redirect.github.com/redis/go-redis/releases/tag/v9.17.0):
9.17.0

[Compare
Source](https://redirect.github.com/redis/go-redis/compare/v9.16.0...v9.17.0)

#### 🚀 Highlights

##### Redis 8.4 Support

Added support for Redis 8.4, including new commands and features
([#&#8203;3572](https://redirect.github.com/redis/go-redis/pull/3572))

##### Typed Errors

Introduced typed errors for better error handling using `errors.As`
instead of string checks. Errors can now be wrapped and set to commands
in hooks without breaking library functionality
([#&#8203;3602](https://redirect.github.com/redis/go-redis/pull/3602))

##### New Commands

- **CAS/CAD Commands**: Added support for
Compare-And-Set/Compare-And-Delete operations with conditional matching
(`IFEQ`, `IFNE`, `IFDEQ`, `IFDNE`)
([#&#8203;3583](https://redirect.github.com/redis/go-redis/pull/3583),
[#&#8203;3595](https://redirect.github.com/redis/go-redis/pull/3595))
- **MSETEX**: Atomically set multiple key-value pairs with expiration
options and conditional modes
([#&#8203;3580](https://redirect.github.com/redis/go-redis/pull/3580))
- **XReadGroup CLAIM**: Consume both incoming and idle pending entries
from streams in a single call
([#&#8203;3578](https://redirect.github.com/redis/go-redis/pull/3578))
- **ACL Commands**: Added `ACLGenPass`, `ACLUsers`, and `ACLWhoAmI`
([#&#8203;3576](https://redirect.github.com/redis/go-redis/pull/3576))
- **SLOWLOG Commands**: Added `SLOWLOG LEN` and `SLOWLOG RESET`
([#&#8203;3585](https://redirect.github.com/redis/go-redis/pull/3585))
- **LATENCY Commands**: Added `LATENCY LATEST` and `LATENCY RESET`
([#&#8203;3584](https://redirect.github.com/redis/go-redis/pull/3584))

##### Search & Vector Improvements

- **Hybrid Search**: Added **EXPERIMENTAL** support for the new
`FT.HYBRID` command
([#&#8203;3573](https://redirect.github.com/redis/go-redis/pull/3573))
- **Vector Range**: Added `VRANGE` command for vector sets
([#&#8203;3543](https://redirect.github.com/redis/go-redis/pull/3543))
- **FT.INFO Enhancements**: Added vector-specific attributes in FT.INFO
response
([#&#8203;3596](https://redirect.github.com/redis/go-redis/pull/3596))

##### Connection Pool Improvements

- **Improved Connection Success Rate**: Implemented FIFO queue-based
fairness and context pattern for connection creation to prevent
premature cancellation under high concurrency
([#&#8203;3518](https://redirect.github.com/redis/go-redis/pull/3518))
- **Connection State Machine**: Resolved race conditions and improved
pool performance with proper state tracking
([#&#8203;3559](https://redirect.github.com/redis/go-redis/pull/3559))
- **Pool Performance**: Significant performance improvements with faster
semaphores, lockless hook manager, and reduced allocations (47-67%
faster Get/Put operations)
([#&#8203;3565](https://redirect.github.com/redis/go-redis/pull/3565))

##### Metrics & Observability

- **Canceled Metric Attribute**: Added 'canceled' metrics attribute to
distinguish context cancellation errors from other errors
([#&#8203;3566](https://redirect.github.com/redis/go-redis/pull/3566))

####  New Features

- Typed errors with wrapping support
([#&#8203;3602](https://redirect.github.com/redis/go-redis/pull/3602))
by [@&#8203;ndyakov](https://redirect.github.com/ndyakov)
- CAS/CAD commands (marked as experimental)
([#&#8203;3583](https://redirect.github.com/redis/go-redis/pull/3583),
[#&#8203;3595](https://redirect.github.com/redis/go-redis/pull/3595)) by
[@&#8203;ndyakov](https://redirect.github.com/ndyakov),
[@&#8203;htemelski-redis](https://redirect.github.com/htemelski-redis)
- MSETEX command support
([#&#8203;3580](https://redirect.github.com/redis/go-redis/pull/3580))
by [@&#8203;ofekshenawa](https://redirect.github.com/ofekshenawa)
- XReadGroup CLAIM argument
([#&#8203;3578](https://redirect.github.com/redis/go-redis/pull/3578))
by [@&#8203;ofekshenawa](https://redirect.github.com/ofekshenawa)
- ACL commands: GenPass, Users, WhoAmI
([#&#8203;3576](https://redirect.github.com/redis/go-redis/pull/3576))
by [@&#8203;destinyoooo](https://redirect.github.com/destinyoooo)
- SLOWLOG commands: LEN, RESET
([#&#8203;3585](https://redirect.github.com/redis/go-redis/pull/3585))
by [@&#8203;destinyoooo](https://redirect.github.com/destinyoooo)
- LATENCY commands: LATEST, RESET
([#&#8203;3584](https://redirect.github.com/redis/go-redis/pull/3584))
by [@&#8203;destinyoooo](https://redirect.github.com/destinyoooo)
- Hybrid search command (FT.HYBRID)
([#&#8203;3573](https://redirect.github.com/redis/go-redis/pull/3573))
by
[@&#8203;htemelski-redis](https://redirect.github.com/htemelski-redis)
- Vector range command (VRANGE)
([#&#8203;3543](https://redirect.github.com/redis/go-redis/pull/3543))
by [@&#8203;cxljs](https://redirect.github.com/cxljs)
- Vector-specific attributes in FT.INFO
([#&#8203;3596](https://redirect.github.com/redis/go-redis/pull/3596))
by [@&#8203;ndyakov](https://redirect.github.com/ndyakov)
- Improved connection pool success rate with FIFO queue
([#&#8203;3518](https://redirect.github.com/redis/go-redis/pull/3518))
by [@&#8203;cyningsun](https://redirect.github.com/cyningsun)
- Canceled metrics attribute for context errors
([#&#8203;3566](https://redirect.github.com/redis/go-redis/pull/3566))
by [@&#8203;pvragov](https://redirect.github.com/pvragov)

#### 🐛 Bug Fixes

- Fixed Failover Client MaintNotificationsConfig
([#&#8203;3600](https://redirect.github.com/redis/go-redis/pull/3600))
by [@&#8203;ajax16384](https://redirect.github.com/ajax16384)
- Fixed ACLGenPass function to use the bit parameter
([#&#8203;3597](https://redirect.github.com/redis/go-redis/pull/3597))
by [@&#8203;destinyoooo](https://redirect.github.com/destinyoooo)
- Return error instead of panic from commands
([#&#8203;3568](https://redirect.github.com/redis/go-redis/pull/3568))
by [@&#8203;dragneelfps](https://redirect.github.com/dragneelfps)
- Safety harness in `joinErrors` to prevent panic
([#&#8203;3577](https://redirect.github.com/redis/go-redis/pull/3577))
by [@&#8203;manisharma](https://redirect.github.com/manisharma)

####  Performance

- Connection state machine with race condition fixes
([#&#8203;3559](https://redirect.github.com/redis/go-redis/pull/3559))
by [@&#8203;ndyakov](https://redirect.github.com/ndyakov)
- Pool performance improvements: 47-67% faster Get/Put, 33% less memory,
50% fewer allocations
([#&#8203;3565](https://redirect.github.com/redis/go-redis/pull/3565))
by [@&#8203;ndyakov](https://redirect.github.com/ndyakov)

#### 🧪 Testing & Infrastructure

- Updated to Redis 8.4.0 image
([#&#8203;3603](https://redirect.github.com/redis/go-redis/pull/3603))
by [@&#8203;ndyakov](https://redirect.github.com/ndyakov)
- Added Redis 8.4-RC1-pre to CI
([#&#8203;3572](https://redirect.github.com/redis/go-redis/pull/3572))
by [@&#8203;ndyakov](https://redirect.github.com/ndyakov)
- Refactored tests for idiomatic Go
([#&#8203;3561](https://redirect.github.com/redis/go-redis/pull/3561),
[#&#8203;3562](https://redirect.github.com/redis/go-redis/pull/3562),
[#&#8203;3563](https://redirect.github.com/redis/go-redis/pull/3563)) by
[@&#8203;12ya](https://redirect.github.com/12ya)

#### 👥 Contributors

We'd like to thank all the contributors who worked on this release!

[@&#8203;12ya](https://redirect.github.com/12ya),
[@&#8203;ajax16384](https://redirect.github.com/ajax16384),
[@&#8203;cxljs](https://redirect.github.com/cxljs),
[@&#8203;cyningsun](https://redirect.github.com/cyningsun),
[@&#8203;destinyoooo](https://redirect.github.com/destinyoooo),
[@&#8203;dragneelfps](https://redirect.github.com/dragneelfps),
[@&#8203;htemelski-redis](https://redirect.github.com/htemelski-redis),
[@&#8203;manisharma](https://redirect.github.com/manisharma),
[@&#8203;ndyakov](https://redirect.github.com/ndyakov),
[@&#8203;ofekshenawa](https://redirect.github.com/ofekshenawa),
[@&#8203;pvragov](https://redirect.github.com/pvragov)

***

**Full Changelog**:
<https://github.com/redis/go-redis/compare/v9.16.0...v9.17.0>

</details>

---

### Configuration

📅 **Schedule**: Branch creation - At any time (no schedule defined),
Automerge - At any time (no schedule defined).

🚦 **Automerge**: Disabled by config. Please merge this manually once you
are satisfied.

♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the
rebase/retry checkbox.

🔕 **Ignore**: Close this PR and you won't be reminded about this update
again.

---

- [ ] <!-- rebase-check -->If you want to rebase/retry this PR, check
this box

---

This PR was generated by [Mend Renovate](https://mend.io/renovate/).
View the [repository job
log](https://developer.mend.io/github/googleapis/genai-toolbox).

<!--renovate-debug:eyJjcmVhdGVkSW5WZXIiOiI0Mi4xNi4xIiwidXBkYXRlZEluVmVyIjoiNDIuMzIuMiIsInRhcmdldEJyYW5jaCI6Im1haW4iLCJsYWJlbHMiOltdfQ==-->

Co-authored-by: Averi Kitsch <akitsch@google.com>
2025-12-19 13:59:00 -08:00
Mend Renovate
1c353a3c8e chore(deps): update module github.com/elastic/elastic-transport-go/v8 to v8.8.0 (#1989)
This PR contains the following updates:

| Package | Change |
[Age](https://docs.renovatebot.com/merge-confidence/) |
[Confidence](https://docs.renovatebot.com/merge-confidence/) |
|---|---|---|---|
|
[github.com/elastic/elastic-transport-go/v8](https://redirect.github.com/elastic/elastic-transport-go)
| `v8.7.0` -> `v8.8.0` |
![age](https://developer.mend.io/api/mc/badges/age/go/github.com%2felastic%2felastic-transport-go%2fv8/v8.8.0?slim=true)
|
![confidence](https://developer.mend.io/api/mc/badges/confidence/go/github.com%2felastic%2felastic-transport-go%2fv8/v8.7.0/v8.8.0?slim=true)
|

---

### Release Notes

<details>
<summary>elastic/elastic-transport-go
(github.com/elastic/elastic-transport-go/v8)</summary>

###
[`v8.8.0`](https://redirect.github.com/elastic/elastic-transport-go/releases/tag/v8.8.0)

[Compare
Source](https://redirect.github.com/elastic/elastic-transport-go/compare/v8.7.0...v8.8.0)

##### Features

- add a Close method to transport
([#&#8203;36](https://redirect.github.com/elastic/elastic-transport-go/issues/36))
([b2d94de](b2d94deb8a))
- add interceptor pattern
([#&#8203;35](https://redirect.github.com/elastic/elastic-transport-go/issues/35))
([c2d0c18](c2d0c18106))

</details>

---

### Configuration

📅 **Schedule**: Branch creation - At any time (no schedule defined),
Automerge - At any time (no schedule defined).

🚦 **Automerge**: Disabled by config. Please merge this manually once you
are satisfied.

♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the
rebase/retry checkbox.

🔕 **Ignore**: Close this PR and you won't be reminded about this update
again.

---

- [ ] <!-- rebase-check -->If you want to rebase/retry this PR, check
this box

---

This PR was generated by [Mend Renovate](https://mend.io/renovate/).
View the [repository job
log](https://developer.mend.io/github/googleapis/genai-toolbox).

<!--renovate-debug:eyJjcmVhdGVkSW5WZXIiOiI0Mi4xNi4xIiwidXBkYXRlZEluVmVyIjoiNDIuMzIuMiIsInRhcmdldEJyYW5jaCI6Im1haW4iLCJsYWJlbHMiOltdfQ==-->

Co-authored-by: Averi Kitsch <akitsch@google.com>
2025-12-19 20:14:14 +00:00
Mend Renovate
a02ca45ba3 chore(deps): update module github.com/godror/godror to v0.49.6 (#2199)
This PR contains the following updates:

| Package | Change |
[Age](https://docs.renovatebot.com/merge-confidence/) |
[Confidence](https://docs.renovatebot.com/merge-confidence/) |
|---|---|---|---|
| [github.com/godror/godror](https://redirect.github.com/godror/godror)
| `v0.49.4` -> `v0.49.6` |
![age](https://developer.mend.io/api/mc/badges/age/go/github.com%2fgodror%2fgodror/v0.49.6?slim=true)
|
![confidence](https://developer.mend.io/api/mc/badges/confidence/go/github.com%2fgodror%2fgodror/v0.49.4/v0.49.6?slim=true)
|

---

### Release Notes

<details>
<summary>godror/godror (github.com/godror/godror)</summary>

###
[`v0.49.6`](https://redirect.github.com/godror/godror/blob/HEAD/CHANGELOG.md#v0496)

[Compare
Source](https://redirect.github.com/godror/godror/compare/v0.49.5...v0.49.6)

##### Added

- \*bool == nil -> NULL in DB.

###
[`v0.49.5`](https://redirect.github.com/godror/godror/blob/HEAD/CHANGELOG.md#v0495)

[Compare
Source](https://redirect.github.com/godror/godror/compare/v0.49.4...v0.49.5)

- ODPI-C v5.6.4

</details>

---

### Configuration

📅 **Schedule**: Branch creation - At any time (no schedule defined),
Automerge - At any time (no schedule defined).

🚦 **Automerge**: Disabled by config. Please merge this manually once you
are satisfied.

♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the
rebase/retry checkbox.

🔕 **Ignore**: Close this PR and you won't be reminded about this update
again.

---

- [ ] <!-- rebase-check -->If you want to rebase/retry this PR, check
this box

---

This PR was generated by [Mend Renovate](https://mend.io/renovate/).
View the [repository job
log](https://developer.mend.io/github/googleapis/genai-toolbox).

<!--renovate-debug:eyJjcmVhdGVkSW5WZXIiOiI0Mi41OS4wIiwidXBkYXRlZEluVmVyIjoiNDIuNTkuMCIsInRhcmdldEJyYW5jaCI6Im1haW4iLCJsYWJlbHMiOltdfQ==-->

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
2025-12-19 03:39:25 +00:00
Yuan Teoh
8217d1424d chore: dedup userAgentRoundTripper into util (#2198)
Dedup userAgentRoundTripper into util where userAgent related code are
placed.
2025-12-18 19:19:14 -08:00
release-please[bot]
f520b4ed8a chore(main): release 0.24.0 (#2162)
🤖 I have created a release *beep* *boop*
---


##
[0.24.0](https://github.com/googleapis/genai-toolbox/compare/v0.23.0...v0.24.0)
(2025-12-19)


### Features

* **sources/cloud-gemini-data-analytics:** Add the Gemini Data Analytics
(GDA) integration for DB NL2SQL conversion to Toolbox
([#2181](https://github.com/googleapis/genai-toolbox/issues/2181))
([aa270b2](aa270b2630))
* **source/cloudsqlmysql:** Add support for IAM authentication in Cloud
SQL MySQL source
([#2050](https://github.com/googleapis/genai-toolbox/issues/2050))
([af3d3c5](af3d3c5204))
* **sources/oracle:** Add Oracle OCI and Wallet support
([#1945](https://github.com/googleapis/genai-toolbox/issues/1945))
([8ea39ec](8ea39ec32f))
* Support combining prebuilt and custom tool configurations
([#2188](https://github.com/googleapis/genai-toolbox/issues/2188))
([5788605](5788605818))
* **tools/mysql-get-query-plan:** Add new `mysql-get-query-plan` tool
for MySQL source
([#2123](https://github.com/googleapis/genai-toolbox/issues/2123))
([0641da0](0641da0353))


### Bug Fixes

* **spanner:** Move list graphs validation to runtime
([#2154](https://github.com/googleapis/genai-toolbox/issues/2154))
([914b3ee](914b3eefda))


---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com>
Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
2025-12-19 02:07:06 +00:00
Yuan Teoh
80315a0ebd chore: release 0.24.0 (#2201)
Release-As: 0.24.0
2025-12-19 01:44:04 +00:00
dishaprakash
5788605818 feat: Support combining prebuilt and custom tool configurations (#2188)
## Description

This PR updates the CLI to allow the --prebuilt flag to be used
simultaneously with custom tool flags (--tools-file, --tools-files, or
--tools-folder). This enables users to extend a standard prebuilt
environment with their own custom tools and configurations.

### Key changes

- Sequential Loading: Load prebuilt configurations first, then
accumulate any specified custom configurations before merging.

- Smart Defaults: Updated logic to only default to tools.yaml if no
configuration flags are provided.

- Legacy Auth Compatibility: Implemented an additive merge strategy for
authentication. Legacy authSources from custom files are merged into the
modern authServices map used by prebuilt tools.

- Strict Validation: To prevent ambiguity, the server will throw an
explicit error if a legacy authSource name conflicts with an existing
authService name (e.g., from a prebuilt config).

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involve a breaking change

🛠️ Fixes https://github.com/googleapis/genai-toolbox/issues/1220

---------

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
2025-12-18 17:21:08 -08:00
gRedHeadphone
0641da0353 feat(tools/mysql-get-query-plan): tool impl + docs + tests (#2123)
## Description

Tool mysql-get-query-plan implementation, along with tests and docs.
Tool used to get information about how MySQL executes a SQL statement
(EXPLAIN).

## PR Checklist

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #1692

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Averi Kitsch <akitsch@google.com>
Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
2025-12-19 01:02:16 +00:00
Yuan Teoh
c9b775d38e tests: add if exists to spanner drop table sql (#2200)
Update `DROP TABLE %table_name` to `DROP TABLE IF EXISTS %tablename`.
The drop table statement often fail to run. This halts the process and
causes context timeout, and eventually failing the integration tests.
2025-12-19 00:39:50 +00:00
Twisha Bansal
b25fd83bb9 Merge branch 'main' into anti-prompt-quick 2025-12-18 10:32:55 +05:30
Twisha Bansal
ec844d5087 restructure 2025-12-17 16:09:11 +05:30
Twisha Bansal
5978a746fd add assets 2025-12-17 16:02:23 +05:30
Twisha Bansal
5b79f712e4 docs: add prompts antigravity quickstart 2025-12-17 15:59:43 +05:30
253 changed files with 5326 additions and 6640 deletions

View File

@@ -338,7 +338,7 @@ steps:
.ci/test_with_coverage.sh \
"Spanner" \
spanner \
spanner
spanner || echo "Integration tests failed." # ignore test failures
- id: "neo4j"
name: golang:1

View File

@@ -24,5 +24,23 @@
],
pinDigests: true,
},
{
groupName: 'Go',
matchManagers: [
'gomod',
],
},
{
groupName: 'Node',
matchManagers: [
'npm',
],
},
{
groupName: 'Pip',
matchManagers: [
'pip_requirements',
],
},
],
}

View File

@@ -51,6 +51,10 @@ ignoreFiles = ["quickstart/shared", "quickstart/python", "quickstart/js", "quick
# Add a new version block here before every release
# The order of versions in this file is mirrored into the dropdown
[[params.versions]]
version = "v0.24.0"
url = "https://googleapis.github.io/genai-toolbox/v0.24.0/"
[[params.versions]]
version = "v0.23.0"
url = "https://googleapis.github.io/genai-toolbox/v0.23.0/"

View File

@@ -1,5 +1,22 @@
# Changelog
## [0.24.0](https://github.com/googleapis/genai-toolbox/compare/v0.23.0...v0.24.0) (2025-12-19)
### Features
* **sources/cloud-gemini-data-analytics:** Add the Gemini Data Analytics (GDA) integration for DB NL2SQL conversion to Toolbox ([#2181](https://github.com/googleapis/genai-toolbox/issues/2181)) ([aa270b2](https://github.com/googleapis/genai-toolbox/commit/aa270b2630da2e3d618db804ca95550445367dbc))
* **source/cloudsqlmysql:** Add support for IAM authentication in Cloud SQL MySQL source ([#2050](https://github.com/googleapis/genai-toolbox/issues/2050)) ([af3d3c5](https://github.com/googleapis/genai-toolbox/commit/af3d3c52044bea17781b89ce4ab71ff0f874ac20))
* **sources/oracle:** Add Oracle OCI and Wallet support ([#1945](https://github.com/googleapis/genai-toolbox/issues/1945)) ([8ea39ec](https://github.com/googleapis/genai-toolbox/commit/8ea39ec32fbbaa97939c626fec8c5d86040ed464))
* Support combining prebuilt and custom tool configurations ([#2188](https://github.com/googleapis/genai-toolbox/issues/2188)) ([5788605](https://github.com/googleapis/genai-toolbox/commit/57886058188aa5d2a51d5846a98bc6d8a650edd1))
* **tools/mysql-get-query-plan:** Add new `mysql-get-query-plan` tool for MySQL source ([#2123](https://github.com/googleapis/genai-toolbox/issues/2123)) ([0641da0](https://github.com/googleapis/genai-toolbox/commit/0641da0353857317113b2169e547ca69603ddfde))
### Bug Fixes
* **spanner:** Move list graphs validation to runtime ([#2154](https://github.com/googleapis/genai-toolbox/issues/2154)) ([914b3ee](https://github.com/googleapis/genai-toolbox/commit/914b3eefda40a650efe552d245369e007277dab5))
## [0.23.0](https://github.com/googleapis/genai-toolbox/compare/v0.22.0...v0.23.0) (2025-12-11)

View File

@@ -140,7 +140,7 @@ To install Toolbox as a binary:
>
> ```sh
> # see releases page for other versions
> export VERSION=0.23.0
> export VERSION=0.24.0
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
> chmod +x toolbox
> ```
@@ -153,7 +153,7 @@ To install Toolbox as a binary:
>
> ```sh
> # see releases page for other versions
> export VERSION=0.23.0
> export VERSION=0.24.0
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
> chmod +x toolbox
> ```
@@ -166,7 +166,7 @@ To install Toolbox as a binary:
>
> ```sh
> # see releases page for other versions
> export VERSION=0.23.0
> export VERSION=0.24.0
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
> chmod +x toolbox
> ```
@@ -179,7 +179,7 @@ To install Toolbox as a binary:
>
> ```cmd
> :: see releases page for other versions
> set VERSION=0.23.0
> set VERSION=0.24.0
> curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
> ```
>
@@ -191,7 +191,7 @@ To install Toolbox as a binary:
>
> ```powershell
> # see releases page for other versions
> $VERSION = "0.23.0"
> $VERSION = "0.24.0"
> curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
> ```
>
@@ -204,7 +204,7 @@ You can also install Toolbox as a container:
```sh
# see releases page for other versions
export VERSION=0.23.0
export VERSION=0.24.0
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
```
@@ -228,7 +228,7 @@ To install from source, ensure you have the latest version of
[Go installed](https://go.dev/doc/install), and then run the following command:
```sh
go install github.com/googleapis/genai-toolbox@v0.23.0
go install github.com/googleapis/genai-toolbox@v0.24.0
```
<!-- {x-release-please-end} -->

View File

@@ -169,6 +169,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqllisttables"
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql"
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan"
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllistactivequeries"
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablefragmentation"
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables"
@@ -234,6 +235,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
_ "github.com/googleapis/genai-toolbox/internal/sources/cassandra"
_ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse"
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
@@ -354,12 +356,12 @@ func NewCommand(opts ...Option) *Command {
flags.StringVarP(&cmd.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.")
flags.IntVarP(&cmd.cfg.Port, "port", "p", 5000, "Port the server will listen on.")
flags.StringVar(&cmd.tools_file, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --prebuilt.")
flags.StringVar(&cmd.tools_file, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.")
// deprecate tools_file
_ = flags.MarkDeprecated("tools_file", "please use --tools-file instead")
flags.StringVar(&cmd.tools_file, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --prebuilt, --tools-files, or --tools-folder.")
flags.StringSliceVar(&cmd.tools_files, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --prebuilt, --tools-file, or --tools-folder.")
flags.StringVar(&cmd.tools_folder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --prebuilt, --tools-file, or --tools-files.")
flags.StringVar(&cmd.tools_file, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.")
flags.StringSliceVar(&cmd.tools_files, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file, or --tools-folder.")
flags.StringVar(&cmd.tools_folder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file, or --tools-files.")
flags.Var(&cmd.cfg.LogLevel, "log-level", "Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'.")
flags.Var(&cmd.cfg.LoggingFormat, "logging-format", "Specify logging format to use. Allowed: 'standard' or 'JSON'.")
flags.BoolVar(&cmd.cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.")
@@ -367,7 +369,7 @@ func NewCommand(opts ...Option) *Command {
flags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
// Fetch prebuilt tools sources to customize the help description
prebuiltHelp := fmt.Sprintf(
"Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. Allowed: '%s'.",
"Use a prebuilt tool configuration by source type. Allowed: '%s'.",
strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"),
)
flags.StringVar(&cmd.prebuiltConfig, "prebuilt", "", prebuiltHelp)
@@ -461,6 +463,9 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
if _, exists := merged.AuthSources[name]; exists {
conflicts = append(conflicts, fmt.Sprintf("authSource '%s' (file #%d)", name, fileIndex+1))
} else {
if merged.AuthSources == nil {
merged.AuthSources = make(server.AuthServiceConfigs)
}
merged.AuthSources[name] = authSource
}
}
@@ -837,16 +842,10 @@ func run(cmd *Command) error {
}
}()
var toolsFile ToolsFile
var allToolsFiles []ToolsFile
// Load Prebuilt Configuration
if cmd.prebuiltConfig != "" {
// Make sure --prebuilt and --tools-file/--tools-files/--tools-folder flags are mutually exclusive
if cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != "" {
errMsg := fmt.Errorf("--prebuilt and --tools-file/--tools-files/--tools-folder flags cannot be used simultaneously")
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
// Use prebuilt tools
buf, err := prebuiltconfigs.Get(cmd.prebuiltConfig)
if err != nil {
cmd.logger.ErrorContext(ctx, err.Error())
@@ -857,72 +856,96 @@ func run(cmd *Command) error {
// Append prebuilt.source to Version string for the User Agent
cmd.cfg.Version += "+prebuilt." + cmd.prebuiltConfig
toolsFile, err = parseToolsFile(ctx, buf)
parsed, err := parseToolsFile(ctx, buf)
if err != nil {
errMsg := fmt.Errorf("unable to parse prebuilt tool configuration: %w", err)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
} else if len(cmd.tools_files) > 0 {
// Make sure --tools-file, --tools-files, and --tools-folder flags are mutually exclusive
if cmd.tools_file != "" || cmd.tools_folder != "" {
errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously")
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
// Use multiple tools files
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files)))
var err error
toolsFile, err = loadAndMergeToolsFiles(ctx, cmd.tools_files)
if err != nil {
cmd.logger.ErrorContext(ctx, err.Error())
return err
}
} else if cmd.tools_folder != "" {
// Make sure --tools-folder and other flags are mutually exclusive
if cmd.tools_file != "" || len(cmd.tools_files) > 0 {
errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously")
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
// Use tools folder
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder))
var err error
toolsFile, err = loadAndMergeToolsFolder(ctx, cmd.tools_folder)
if err != nil {
cmd.logger.ErrorContext(ctx, err.Error())
return err
}
} else {
// Set default value of tools-file flag to tools.yaml
if cmd.tools_file == "" {
cmd.tools_file = "tools.yaml"
}
// Read single tool file contents
buf, err := os.ReadFile(cmd.tools_file)
if err != nil {
errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, err)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
toolsFile, err = parseToolsFile(ctx, buf)
if err != nil {
errMsg := fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
allToolsFiles = append(allToolsFiles, parsed)
}
cmd.cfg.SourceConfigs, cmd.cfg.AuthServiceConfigs, cmd.cfg.ToolConfigs, cmd.cfg.ToolsetConfigs, cmd.cfg.PromptConfigs = toolsFile.Sources, toolsFile.AuthServices, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts
// Determine if Custom Files should be loaded
// Check for explicit custom flags
isCustomConfigured := cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != ""
authSourceConfigs := toolsFile.AuthSources
// Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags)
useDefaultToolsFile := cmd.prebuiltConfig == "" && !isCustomConfigured
if useDefaultToolsFile {
cmd.tools_file = "tools.yaml"
isCustomConfigured = true
}
// Load Custom Configurations
if isCustomConfigured {
// Enforce exclusivity among custom flags (tools-file vs tools-files vs tools-folder)
if (cmd.tools_file != "" && len(cmd.tools_files) > 0) ||
(cmd.tools_file != "" && cmd.tools_folder != "") ||
(len(cmd.tools_files) > 0 && cmd.tools_folder != "") {
errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously")
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
var customTools ToolsFile
var err error
if len(cmd.tools_files) > 0 {
// Use tools-files
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files)))
customTools, err = loadAndMergeToolsFiles(ctx, cmd.tools_files)
} else if cmd.tools_folder != "" {
// Use tools-folder
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder))
customTools, err = loadAndMergeToolsFolder(ctx, cmd.tools_folder)
} else {
// Use single file (tools-file or default `tools.yaml`)
buf, readFileErr := os.ReadFile(cmd.tools_file)
if readFileErr != nil {
errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, readFileErr)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
customTools, err = parseToolsFile(ctx, buf)
if err != nil {
err = fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
}
}
if err != nil {
cmd.logger.ErrorContext(ctx, err.Error())
return err
}
allToolsFiles = append(allToolsFiles, customTools)
}
// Merge Everything
// This will error if custom tools collide with prebuilt tools
finalToolsFile, err := mergeToolsFiles(allToolsFiles...)
if err != nil {
cmd.logger.ErrorContext(ctx, err.Error())
return err
}
cmd.cfg.SourceConfigs = finalToolsFile.Sources
cmd.cfg.AuthServiceConfigs = finalToolsFile.AuthServices
cmd.cfg.ToolConfigs = finalToolsFile.Tools
cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets
cmd.cfg.PromptConfigs = finalToolsFile.Prompts
authSourceConfigs := finalToolsFile.AuthSources
if authSourceConfigs != nil {
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
cmd.cfg.AuthServiceConfigs = authSourceConfigs
for k, v := range authSourceConfigs {
if _, exists := cmd.cfg.AuthServiceConfigs[k]; exists {
errMsg := fmt.Errorf("resource conflict detected: authSource '%s' has the same name as an existing authService. Please rename your authSource", k)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
cmd.cfg.AuthServiceConfigs[k] = v
}
}
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
@@ -973,9 +996,8 @@ func run(cmd *Command) error {
}()
}
watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder)
if !cmd.cfg.DisableReload {
if isCustomConfigured && !cmd.cfg.DisableReload {
watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder)
// start watching the file(s) or folder for changes to trigger dynamic reloading
go watchChanges(ctx, watchDirs, watchedFiles, s)
}

View File

@@ -92,6 +92,21 @@ func invokeCommand(args []string) (*Command, string, error) {
return c, buf.String(), err
}
// invokeCommandWithContext executes the command with a context and returns the captured output.
func invokeCommandWithContext(ctx context.Context, args []string) (*Command, string, error) {
// Capture output using a buffer
buf := new(bytes.Buffer)
c := NewCommand(WithStreams(buf, buf))
c.SetArgs(args)
c.SilenceUsage = true
c.SilenceErrors = true
c.SetContext(ctx)
err := c.Execute()
return c, buf.String(), err
}
func TestVersion(t *testing.T) {
data, err := os.ReadFile("version.txt")
if err != nil {
@@ -1755,11 +1770,6 @@ func TestMutuallyExclusiveFlags(t *testing.T) {
args []string
errString string
}{
{
desc: "--prebuilt and --tools-file",
args: []string{"--prebuilt", "alloydb", "--tools-file", "my.yaml"},
errString: "--prebuilt and --tools-file/--tools-files/--tools-folder flags cannot be used simultaneously",
},
{
desc: "--tools-file and --tools-files",
args: []string{"--tools-file", "my.yaml", "--tools-files", "a.yaml,b.yaml"},
@@ -1902,3 +1912,228 @@ func TestMergeToolsFiles(t *testing.T) {
})
}
}
func TestPrebuiltAndCustomTools(t *testing.T) {
t.Setenv("SQLITE_DATABASE", "test.db")
// Setup custom tools file
customContent := `
tools:
custom_tool:
kind: http
source: my-http
method: GET
path: /
description: "A custom tool for testing"
sources:
my-http:
kind: http
baseUrl: http://example.com
`
customFile := filepath.Join(t.TempDir(), "custom.yaml")
if err := os.WriteFile(customFile, []byte(customContent), 0644); err != nil {
t.Fatal(err)
}
// Tool Conflict File
// SQLite prebuilt has a tool named 'list_tables'
toolConflictContent := `
tools:
list_tables:
kind: http
source: my-http
method: GET
path: /
description: "Conflicting tool"
sources:
my-http:
kind: http
baseUrl: http://example.com
`
toolConflictFile := filepath.Join(t.TempDir(), "tool_conflict.yaml")
if err := os.WriteFile(toolConflictFile, []byte(toolConflictContent), 0644); err != nil {
t.Fatal(err)
}
// Source Conflict File
// SQLite prebuilt has a source named 'sqlite-source'
sourceConflictContent := `
sources:
sqlite-source:
kind: http
baseUrl: http://example.com
tools:
dummy_tool:
kind: http
source: sqlite-source
method: GET
path: /
description: "Dummy"
`
sourceConflictFile := filepath.Join(t.TempDir(), "source_conflict.yaml")
if err := os.WriteFile(sourceConflictFile, []byte(sourceConflictContent), 0644); err != nil {
t.Fatal(err)
}
// Toolset Conflict File
// SQLite prebuilt has a toolset named 'sqlite_database_tools'
toolsetConflictContent := `
sources:
dummy-src:
kind: http
baseUrl: http://example.com
tools:
dummy_tool:
kind: http
source: dummy-src
method: GET
path: /
description: "Dummy"
toolsets:
sqlite_database_tools:
- dummy_tool
`
toolsetConflictFile := filepath.Join(t.TempDir(), "toolset_conflict.yaml")
if err := os.WriteFile(toolsetConflictFile, []byte(toolsetConflictContent), 0644); err != nil {
t.Fatal(err)
}
//Legacy Auth File
authContent := `
authSources:
legacy-auth:
kind: google
clientId: "test-client-id"
`
authFile := filepath.Join(t.TempDir(), "auth.yaml")
if err := os.WriteFile(authFile, []byte(authContent), 0644); err != nil {
t.Fatal(err)
}
testCases := []struct {
desc string
args []string
wantErr bool
errString string
cfgCheck func(server.ServerConfig) error
}{
{
desc: "success mixed",
args: []string{"--prebuilt", "sqlite", "--tools-file", customFile},
wantErr: false,
cfgCheck: func(cfg server.ServerConfig) error {
if _, ok := cfg.ToolConfigs["custom_tool"]; !ok {
return fmt.Errorf("custom tool not found")
}
if _, ok := cfg.ToolConfigs["list_tables"]; !ok {
return fmt.Errorf("prebuilt tool 'list_tables' not found")
}
return nil
},
},
{
desc: "tool conflict error",
args: []string{"--prebuilt", "sqlite", "--tools-file", toolConflictFile},
wantErr: true,
errString: "resource conflicts detected",
},
{
desc: "source conflict error",
args: []string{"--prebuilt", "sqlite", "--tools-file", sourceConflictFile},
wantErr: true,
errString: "resource conflicts detected",
},
{
desc: "toolset conflict error",
args: []string{"--prebuilt", "sqlite", "--tools-file", toolsetConflictFile},
wantErr: true,
errString: "resource conflicts detected",
},
{
desc: "legacy auth additive",
args: []string{"--prebuilt", "sqlite", "--tools-file", authFile},
wantErr: false,
cfgCheck: func(cfg server.ServerConfig) error {
if _, ok := cfg.AuthServiceConfigs["legacy-auth"]; !ok {
return fmt.Errorf("legacy auth source not merged into auth services")
}
return nil
},
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
cmd, output, err := invokeCommandWithContext(ctx, tc.args)
if tc.wantErr {
if err == nil {
t.Fatalf("expected an error but got none")
}
if !strings.Contains(err.Error(), tc.errString) {
t.Errorf("expected error message to contain %q, but got %q", tc.errString, err.Error())
}
} else {
if err != nil && err != context.DeadlineExceeded && err != context.Canceled {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(output, "Server ready to serve!") {
t.Errorf("server did not start successfully (no ready message found). Output:\n%s", output)
}
if tc.cfgCheck != nil {
if err := tc.cfgCheck(cmd.cfg); err != nil {
t.Errorf("config check failed: %v", err)
}
}
}
})
}
}
func TestDefaultToolsFileBehavior(t *testing.T) {
t.Setenv("SQLITE_DATABASE", "test.db")
testCases := []struct {
desc string
args []string
expectRun bool
errString string
}{
{
desc: "no flags (defaults to tools.yaml)",
args: []string{},
expectRun: false,
errString: "tools.yaml", // Expect error because tools.yaml doesn't exist in test env
},
{
desc: "prebuilt only (skips tools.yaml)",
args: []string{"--prebuilt", "sqlite"},
expectRun: true,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
_, output, err := invokeCommandWithContext(ctx, tc.args)
if tc.expectRun {
if err != nil && err != context.DeadlineExceeded && err != context.Canceled {
t.Fatalf("expected server start, got error: %v", err)
}
// Verify it actually started
if !strings.Contains(output, "Server ready to serve!") {
t.Errorf("server did not start successfully (no ready message found). Output:\n%s", output)
}
} else {
if err == nil {
t.Fatalf("expected error reading default file, got nil")
}
if !strings.Contains(err.Error(), tc.errString) {
t.Errorf("expected error message to contain %q, but got %q", tc.errString, err.Error())
}
}
})
}
}

View File

@@ -1 +1 @@
0.23.0
0.24.0

View File

@@ -234,7 +234,7 @@
},
"outputs": [],
"source": [
"version = \"0.23.0\" # x-release-please-version\n",
"version = \"0.24.0\" # x-release-please-version\n",
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
"\n",
"# Make the binary executable\n",

View File

@@ -103,7 +103,7 @@ To install Toolbox as a binary on Linux (AMD64):
```sh
# see releases page for other versions
export VERSION=0.23.0
export VERSION=0.24.0
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
chmod +x toolbox
```
@@ -114,7 +114,7 @@ To install Toolbox as a binary on macOS (Apple Silicon):
```sh
# see releases page for other versions
export VERSION=0.23.0
export VERSION=0.24.0
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
chmod +x toolbox
```
@@ -125,7 +125,7 @@ To install Toolbox as a binary on macOS (Intel):
```sh
# see releases page for other versions
export VERSION=0.23.0
export VERSION=0.24.0
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
chmod +x toolbox
```
@@ -136,7 +136,7 @@ To install Toolbox as a binary on Windows (Command Prompt):
```cmd
:: see releases page for other versions
set VERSION=0.23.0
set VERSION=0.24.0
curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
```
@@ -146,7 +146,7 @@ To install Toolbox as a binary on Windows (PowerShell):
```powershell
# see releases page for other versions
$VERSION = "0.23.0"
$VERSION = "0.24.0"
curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
```
@@ -158,7 +158,7 @@ You can also install Toolbox as a container:
```sh
# see releases page for other versions
export VERSION=0.23.0
export VERSION=0.24.0
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
```
@@ -177,7 +177,7 @@ To install from source, ensure you have the latest version of
[Go installed](https://go.dev/doc/install), and then run the following command:
```sh
go install github.com/googleapis/genai-toolbox@v0.23.0
go install github.com/googleapis/genai-toolbox@v0.24.0
```
{{% /tab %}}

View File

@@ -105,7 +105,7 @@ In this section, we will download Toolbox, configure our tools in a
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -0,0 +1,239 @@
---
title: "Prompts using Antigravity"
type: docs
weight: 5
description: >
How to get started using Toolbox prompts locally with PostgreSQL and [Antigravity](https://antigravity.google/).
---
## Before you begin
This guide assumes you have already done the following:
1. Installed [PostgreSQL 16+ and the `psql` client][install-postgres].
[install-postgres]: https://www.postgresql.org/download/
## Step 1: Set up your database
In this section, we will create a database, insert some data that needs to be
accessed by our agent, and create a database user for Toolbox to connect with.
1. Connect to postgres using the `psql` command:
```bash
psql -h 127.0.0.1 -U postgres
```
Here, `postgres` denotes the default postgres superuser.
{{< notice info >}}
#### **Having trouble connecting?**
- **Password Prompt:** If you are prompted for a password for the `postgres`
user and do not know it (or a blank password doesn't work), your PostgreSQL
installation might require a password or a different authentication method.
- **`FATAL: role "postgres" does not exist`:** This error means the default
`postgres` superuser role isn't available under that name on your system.
- **`Connection refused`:** Ensure your PostgreSQL server is actually running.
You can typically check with `sudo systemctl status postgresql` and start it
with `sudo systemctl start postgresql` on Linux systems.
<br/>
#### **Common Solution**
For password issues or if the `postgres` role seems inaccessible directly, try
switching to the `postgres` operating system user first. This user often has
permission to connect without a password for local connections (this is called
peer authentication).
```bash
sudo -i -u postgres
psql -h 127.0.0.1
```
Once you are in the `psql` shell using this method, you can proceed with the
database creation steps below. Afterwards, type `\q` to exit `psql`, and then
`exit` to return to your normal user shell.
If desired, once connected to `psql` as the `postgres` OS user, you can set a
password for the `postgres` _database_ user using: `ALTER USER postgres WITH
PASSWORD 'your_chosen_password';`. This would allow direct connection with `-U
postgres` and a password next time.
{{< /notice >}}
1. Create a new database and a new user:
{{< notice tip >}}
For a real application, it's best to follow the principle of least permission
and only grant the privileges your application needs.
{{< /notice >}}
```sql
CREATE USER toolbox_user WITH PASSWORD 'my-password';
CREATE DATABASE toolbox_db;
GRANT ALL PRIVILEGES ON DATABASE toolbox_db TO toolbox_user;
ALTER DATABASE toolbox_db OWNER TO toolbox_user;
```
1. End the database session:
```bash
\q
```
(If you used `sudo -i -u postgres` and then `psql`, remember you might also
need to type `exit` after `\q` to leave the `postgres` user's shell
session.)
1. Connect to your database with your new user:
```bash
psql -h 127.0.0.1 -U toolbox_user -d toolbox_db
```
1. Create the required tables using the following commands:
```sql
CREATE TABLE users (
id SERIAL PRIMARY KEY,
username VARCHAR(50) NOT NULL,
email VARCHAR(100) UNIQUE NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE TABLE restaurants (
id SERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
location VARCHAR(100)
);
CREATE TABLE reviews (
id SERIAL PRIMARY KEY,
user_id INT REFERENCES users(id),
restaurant_id INT REFERENCES restaurants(id),
rating INT CHECK (rating >= 1 AND rating <= 5),
review_text TEXT,
is_published BOOLEAN DEFAULT false,
moderation_status VARCHAR(50) DEFAULT 'pending_manual_review',
created_at TIMESTAMPTZ DEFAULT NOW()
);
```
1. Insert dummy data into the tables.
```sql
INSERT INTO users (id, username, email) VALUES
(123, 'jane_d', 'jane.d@example.com'),
(124, 'john_s', 'john.s@example.com'),
(125, 'sam_b', 'sam.b@example.com');
INSERT INTO restaurants (id, name, location) VALUES
(455, 'Pizza Palace', '123 Main St'),
(456, 'The Corner Bistro', '456 Oak Ave'),
(457, 'Sushi Spot', '789 Pine Ln');
INSERT INTO reviews (user_id, restaurant_id, rating, review_text, is_published, moderation_status) VALUES
(124, 455, 5, 'Best pizza in town! The crust was perfect.', true, 'approved'),
(125, 457, 4, 'Great sushi, very fresh. A bit pricey but worth it.', true, 'approved'),
(123, 457, 5, 'Absolutely loved the dragon roll. Will be back!', true, 'approved'),
(123, 456, 4, 'The atmosphere was lovely and the food was great. My photo upload might have been weird though.', false, 'pending_manual_review'),
(125, 456, 1, 'This review contains inappropriate language.', false, 'rejected');
```
1. End the database session:
```bash
\q
```
## Step 2: Configure Toolbox
Create a file named `tools.yaml`. This file defines the database connection, the
SQL tools available, and the prompts the agents will use.
```yaml
sources:
my-foodiefind-db:
kind: postgres
host: 127.0.0.1
port: 5432
database: toolbox_db
user: toolbox_user
password: my-password
tools:
find_user_by_email:
kind: postgres-sql
source: my-foodiefind-db
description: Find a user's ID by their email address.
parameters:
- name: email
type: string
description: The email address of the user to find.
statement: SELECT id FROM users WHERE email = $1;
find_restaurant_by_name:
kind: postgres-sql
source: my-foodiefind-db
description: Find a restaurant's ID by its exact name.
parameters:
- name: name
type: string
description: The name of the restaurant to find.
statement: SELECT id FROM restaurants WHERE name = $1;
find_review_by_user_and_restaurant:
kind: postgres-sql
source: my-foodiefind-db
description: Find the full record for a specific review using the user's ID and the restaurant's ID.
parameters:
- name: user_id
type: integer
description: The numerical ID of the user.
- name: restaurant_id
type: integer
description: The numerical ID of the restaurant.
statement: SELECT * FROM reviews WHERE user_id = $1 AND restaurant_id = $2;
prompts:
investigate_missing_review:
description: "Investigates a user's missing review by finding the user, restaurant, and the review itself, then analyzing its status."
arguments:
- name: "user_email"
description: "The email of the user who wrote the review."
- name: "restaurant_name"
description: "The name of the restaurant being reviewed."
messages:
- content: >-
**Goal:** Find the review written by the user with email '{{.user_email}}' for the restaurant named '{{.restaurant_name}}' and understand its status.
**Workflow:**
1. Use the `find_user_by_email` tool with the email '{{.user_email}}' to get the `user_id`.
2. Use the `find_restaurant_by_name` tool with the name '{{.restaurant_name}}' to get the `restaurant_id`.
3. Use the `find_review_by_user_and_restaurant` tool with the `user_id` and `restaurant_id` you just found.
4. Analyze the results from the final tool call. Examine the `is_published` and `moderation_status` fields and explain the review's status to the user in a clear, human-readable sentence.
```
## Step 3: Connect to Antigravity
Configure the Antigravity to talk to your local Toolbox MCP server.
1. Click on "MCP Servers" in the Agent Manager.
![agent_panel](./agent_panel.png)
2. Search for and click on MCP Toolbox for Databases.
![search_toolbox](./search.png)
3. Click the install button and enter the correct path to the `tools.yaml` file.
![install_toolbox](./install.png)
4. You should be able to see your tools in the tools tab.
![tools](./tools.png)
5. Now you can type your query into the agent panel.
```
Investigate the missing review from jane.d@example.com for The corner bistro
```
![quickstart_result](./quickstart_result.png)

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 197 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 145 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 116 KiB

View File

@@ -13,7 +13,7 @@ In this section, we will download Toolbox, configure our tools in a
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -49,19 +49,19 @@ to expose your developer assistant tools to a Looker instance:
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -45,19 +45,19 @@ instance:
<!-- {x-release-please-start-version} -->
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -43,19 +43,19 @@ expose your developer assistant tools to a MySQL instance:
<!-- {x-release-please-start-version} -->
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -44,19 +44,19 @@ expose your developer assistant tools to a Neo4j instance:
<!-- {x-release-please-start-version} -->
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -56,19 +56,19 @@ Omni](https://cloud.google.com/alloydb/omni/current/docs/overview).
<!-- {x-release-please-start-version} -->
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -43,19 +43,19 @@ to expose your developer assistant tools to a SQLite instance:
<!-- {x-release-please-start-version} -->
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/linux/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/arm64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/darwin/amd64/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/windows/amd64/toolbox.exe
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -16,14 +16,14 @@ description: >
| | `--log-level` | Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'. | `info` |
| | `--logging-format` | Specify logging format to use. Allowed: 'standard' or 'JSON'. | `standard` |
| `-p` | `--port` | Port the server will listen on. | `5000` |
| | `--prebuilt` | Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | |
| | `--prebuilt` | Use a prebuilt tool configuration by source type. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | |
| | `--stdio` | Listens via MCP STDIO instead of acting as a remote HTTP server. | |
| | `--telemetry-gcp` | Enable exporting directly to Google Cloud Monitoring. | |
| | `--telemetry-otlp` | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318') | |
| | `--telemetry-service-name` | Sets the value of the service.name resource attribute for telemetry data. | `toolbox` |
| | `--tools-file` | File path specifying the tool configuration. Cannot be used with --prebuilt, --tools-files, or --tools-folder. | |
| | `--tools-files` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --prebuilt, --tools-file, or --tools-folder. | |
| | `--tools-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --prebuilt, --tools-file, or --tools-files. | |
| | `--tools-file` | File path specifying the tool configuration. Cannot be used with --tools-files or --tools-folder. | |
| | `--tools-files` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file or --tools-folder. | |
| | `--tools-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file or --tools-files. | |
| | `--ui` | Launches the Toolbox UI web server. | |
| | `--allowed-origins` | Specifies a list of origins permitted to access this server. | `*` |
| `-v` | `--version` | version for toolbox | |
@@ -46,6 +46,9 @@ description: >
```bash
# Basic server with custom port configuration
./toolbox --tools-file "tools.yaml" --port 8080
# Server with prebuilt + custom tools configurations
./toolbox --tools-file tools.yaml --prebuilt alloydb-postgres
```
### Tool Configuration Sources
@@ -72,8 +75,8 @@ The CLI supports multiple mutually exclusive ways to specify tool configurations
{{< notice tip >}}
The CLI enforces mutual exclusivity between configuration source flags,
preventing simultaneous use of `--prebuilt` with file-based options, and
ensuring only one of `--tools-file`, `--tools-files`, or `--tools-folder` is
preventing simultaneous use of the file-based options ensuring only one of
`--tools-file`, `--tools-files`, or `--tools-folder` is
used at a time.
{{< /notice >}}

View File

@@ -13,6 +13,12 @@ allowing developers to interact with and take action on databases.
See guides, [Connect from your IDE](../how-to/connect-ide/_index.md), for
details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP.
{{< notice tip >}}
You can now use `--prebuilt` along `--tools-file`, `--tools-files`, or
`--tools-folder` to combine prebuilt configs with custom tools.
See [Usage Examples](../reference/cli.md#examples).
{{< /notice >}}
## AlloyDB Postgres
* `--prebuilt` value: `alloydb-postgres`

View File

@@ -31,6 +31,9 @@ to a database by following these instructions][csql-mysql-quickstart].
- [`mysql-list-active-queries`](../tools/mysql/mysql-list-active-queries.md)
List active queries in Cloud SQL for MySQL.
- [`mysql-get-query-plan`](../tools/mysql/mysql-get-query-plan.md)
Provide information about how MySQL executes a SQL statement (EXPLAIN).
- [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md)
List tables in a Cloud SQL for MySQL database.

View File

@@ -25,6 +25,9 @@ reliability, performance, and ease of use.
- [`mysql-list-active-queries`](../tools/mysql/mysql-list-active-queries.md)
List active queries in MySQL.
- [`mysql-get-query-plan`](../tools/mysql/mysql-get-query-plan.md)
Provide information about how MySQL executes a SQL statement (EXPLAIN).
- [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md)
List tables in a MySQL database.

View File

@@ -0,0 +1,39 @@
---
title: "mysql-get-query-plan"
type: docs
weight: 1
description: >
A "mysql-get-query-plan" tool gets the execution plan for a SQL statement against a MySQL
database.
aliases:
- /resources/tools/mysql-get-query-plan
---
## About
A `mysql-get-query-plan` tool gets the execution plan for a SQL statement against a MySQL
database. It's compatible with any of the following sources:
- [cloud-sql-mysql](../../sources/cloud-sql-mysql.md)
- [mysql](../../sources/mysql.md)
`mysql-get-query-plan` takes one input parameter `sql_statement` and gets the execution plan for the SQL
statement against the `source`.
## Example
```yaml
tools:
get_query_plan_tool:
kind: mysql-get-query-plan
source: my-mysql-instance
description: Use this tool to get the execution plan for a sql statement.
```
## Reference
| **field** | **type** | **required** | **description** |
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "mysql-get-query-plan". |
| source | string | true | Name of the source the SQL should execute on. |
| description | string | true | Description of the tool that is passed to the LLM. |

View File

@@ -771,7 +771,7 @@
},
"outputs": [],
"source": [
"version = \"0.23.0\" # x-release-please-version\n",
"version = \"0.24.0\" # x-release-please-version\n",
"! curl -L -o /content/toolbox https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
"\n",
"# Make the binary executable\n",

View File

@@ -123,7 +123,7 @@ In this section, we will download and install the Toolbox binary.
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
export VERSION="0.23.0"
export VERSION="0.24.0"
curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -220,7 +220,7 @@
},
"outputs": [],
"source": [
"version = \"0.23.0\" # x-release-please-version\n",
"version = \"0.24.0\" # x-release-please-version\n",
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
"\n",
"# Make the binary executable\n",

View File

@@ -179,7 +179,7 @@ to use BigQuery, and then run the Toolbox server.
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -98,7 +98,7 @@ In this section, we will download Toolbox, configure our tools in a
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server.
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -48,7 +48,7 @@ In this section, we will download Toolbox and run the Toolbox server.
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server.
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.23.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -1,6 +1,6 @@
{
"name": "mcp-toolbox-for-databases",
"version": "0.23.0",
"version": "0.24.0",
"description": "MCP Toolbox for Databases is an open-source MCP server for more than 30 different datasources.",
"contextFileName": "MCP-TOOLBOX-EXTENSION.md"
}

6
go.mod
View File

@@ -22,7 +22,7 @@ require (
github.com/cenkalti/backoff/v5 v5.0.3
github.com/couchbase/gocb/v2 v2.11.1
github.com/couchbase/tools-common/http v1.0.9
github.com/elastic/elastic-transport-go/v8 v8.7.0
github.com/elastic/elastic-transport-go/v8 v8.8.0
github.com/elastic/go-elasticsearch/v9 v9.2.0
github.com/fsnotify/fsnotify v1.9.0
github.com/go-chi/chi/v5 v5.2.3
@@ -33,7 +33,7 @@ require (
github.com/go-playground/validator/v10 v10.28.0
github.com/go-sql-driver/mysql v1.9.3
github.com/goccy/go-yaml v1.18.0
github.com/godror/godror v0.49.4
github.com/godror/godror v0.49.6
github.com/google/go-cmp v0.7.0
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.6
@@ -42,7 +42,7 @@ require (
github.com/microsoft/go-mssqldb v1.9.3
github.com/nakagami/firebirdsql v0.9.15
github.com/neo4j/neo4j-go-driver/v5 v5.28.4
github.com/redis/go-redis/v9 v9.16.0
github.com/redis/go-redis/v9 v9.17.2
github.com/sijms/go-ora/v2 v2.9.0
github.com/spf13/cobra v1.10.1
github.com/thlib/go-timezone-local v0.0.7

12
go.sum
View File

@@ -822,8 +822,8 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/elastic/elastic-transport-go/v8 v8.7.0 h1:OgTneVuXP2uip4BA658Xi6Hfw+PeIOod2rY3GVMGoVE=
github.com/elastic/elastic-transport-go/v8 v8.7.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk=
github.com/elastic/elastic-transport-go/v8 v8.8.0 h1:7k1Ua+qluFr6p1jfJjGDl97ssJS/P7cHNInzfxgBQAo=
github.com/elastic/elastic-transport-go/v8 v8.8.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk=
github.com/elastic/go-elasticsearch/v9 v9.2.0 h1:COeL/g20+ixnUbffe4Wfbu88emrHjAq/LhVfmrjqRQs=
github.com/elastic/go-elasticsearch/v9 v9.2.0/go.mod h1:2PB5YQPpY5tWbF65MRqzEXA31PZOdXCkloQSOZtU14I=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -915,8 +915,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godror/godror v0.49.4 h1:8kKWKoR17nPX7u10hr4GwD4u10hzTZED9ihdkuzRrKI=
github.com/godror/godror v0.49.4/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8=
github.com/godror/godror v0.49.6 h1:ts4ZGw8uLJ42e1D7aXmVuSrld0/lzUzmIUjuUuQOgGM=
github.com/godror/godror v0.49.6/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8=
github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw=
github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
@@ -1222,8 +1222,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w=
github.com/redis/go-redis/v9 v9.16.0 h1:OotgqgLSRCmzfqChbQyG1PHC3tLNR89DG4jdOERSEP4=
github.com/redis/go-redis/v9 v9.16.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=

View File

@@ -32,16 +32,9 @@ tools:
source: cloud-sql-mysql-source
description: Lists top N (default 10) ongoing queries from processlist and innodb_trx, ordered by execution time in descending order. Returns detailed information of those queries in json format, including process id, query, transaction duration, transaction wait duration, process time, transaction state, process state, username with host, transaction rows locked, transaction rows modified, and db schema.
get_query_plan:
kind: mysql-sql
kind: mysql-get-query-plan
source: cloud-sql-mysql-source
description: "Provide information about how MySQL executes a SQL statement. Common use cases include: 1) analyze query plan to improve its performance, and 2) determine effectiveness of existing indexes and evalueate new ones."
statement: |
EXPLAIN FORMAT=JSON {{.sql_statement}};
templateParameters:
- name: sql_statement
type: string
description: "the SQL statement to explain"
required: true
list_tables:
kind: mysql-list-tables
source: cloud-sql-mysql-source

View File

@@ -36,16 +36,9 @@ tools:
source: mysql-source
description: Lists top N (default 10) ongoing queries from processlist and innodb_trx, ordered by execution time in descending order. Returns detailed information of those queries in json format, including process id, query, transaction duration, transaction wait duration, process time, transaction state, process state, username with host, transaction rows locked, transaction rows modified, and db schema.
get_query_plan:
kind: mysql-sql
kind: mysql-get-query-plan
source: mysql-source
description: "Provide information about how MySQL executes a SQL statement. Common use cases include: 1) analyze query plan to improve its performance, and 2) determine effectiveness of existing indexes and evalueate new ones."
statement: |
EXPLAIN FORMAT=JSON {{.sql_statement}};
templateParameters:
- name: sql_statement
type: string
description: "the SQL statement to explain"
required: true
list_tables:
kind: mysql-list-tables
source: mysql-source

View File

@@ -172,7 +172,14 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(s.ResourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
s.logger.DebugContext(ctx, errMsg.Error())
_ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound))
return
}
if clientAuth {
if accessToken == "" {
err = fmt.Errorf("tool requires client authorization but access token is missing from the request header")
s.logger.DebugContext(ctx, err.Error())
@@ -255,7 +262,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
}
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
if tool.RequiresClientAuthorization(s.ResourceMgr) {
if clientAuth {
// Propagate the original 401/403 error.
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
_ = render.Render(w, r, newErrResponse(err, statusCode))

View File

@@ -77,9 +77,9 @@ func (t MockTool) Authorized(verifiedAuthServices []string) bool {
return !t.unauthorized
}
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) bool {
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
// defaulted to false
return t.requiresClientAuthrorization
return t.requiresClientAuthrorization, nil
}
func (t MockTool) McpManifest() tools.McpManifest {
@@ -119,8 +119,8 @@ func (t MockTool) McpManifest() tools.McpManifest {
return mcpManifest
}
func (t MockTool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
return "Authorization", nil
}
// MockPrompt is used to mock prompts in tests

View File

@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Get access token
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(resourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization(resourceMgr) {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}

View File

@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Get access token
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(resourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization(resourceMgr) {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}

View File

@@ -101,10 +101,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Get access token
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(resourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
@@ -176,7 +186,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization(resourceMgr) {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}

View File

@@ -30,26 +30,6 @@ import (
const SourceKind string = "alloydb-admin"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
@@ -87,10 +67,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
@@ -99,10 +76,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
@@ -136,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetDefaultProject() string {
return s.DefaultProject
}
func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbrestapi.Service, error) {
if s.UseClientOAuth {
token := &oauth2.Token{AccessToken: accessToken}

View File

@@ -29,26 +29,6 @@ import (
const SourceKind string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
@@ -87,10 +67,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
@@ -100,10 +77,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
@@ -133,6 +107,14 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetProjectID() string {
return s.ProjectID
}
func (s *Source) GetBaseURL() string {
return s.BaseURL
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {
@@ -140,10 +122,7 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
}
token := &oauth2.Token{AccessToken: accessToken}
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
baseClient.Transport = &userAgentRoundTripper{
userAgent: s.userAgent,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
return baseClient, nil
}
return s.Client, nil

View File

@@ -29,26 +29,6 @@ import (
const SourceKind string = "cloud-monitoring"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
@@ -86,10 +66,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
@@ -98,18 +75,15 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
s := &Source{
Config: r,
BaseURL: "https://monitoring.googleapis.com",
Client: client,
UserAgent: ua,
baseURL: "https://monitoring.googleapis.com",
client: client,
userAgent: ua,
}
return s, nil
}
@@ -118,9 +92,9 @@ var _ sources.Source = &Source{}
type Source struct {
Config
BaseURL string `yaml:"baseUrl"`
Client *http.Client
UserAgent string
baseURL string
client *http.Client
userAgent string
}
func (s *Source) SourceKind() string {
@@ -131,6 +105,18 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) BaseURL() string {
return s.baseURL
}
func (s *Source) Client() *http.Client {
return s.client
}
func (s *Source) UserAgent() string {
return s.userAgent
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {
@@ -139,7 +125,7 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
token := &oauth2.Token{AccessToken: accessToken}
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
}
return s.Client, nil
return s.client, nil
}
func (s *Source) UseClientAuthorization() bool {

View File

@@ -30,26 +30,6 @@ import (
const SourceKind string = "cloud-sql-admin"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
@@ -88,10 +68,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
@@ -100,10 +77,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
@@ -136,6 +110,10 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetDefaultProject() string {
return s.DefaultProject
}
func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) {
if s.UseClientOAuth {
token := &oauth2.Token{AccessToken: accessToken}

View File

@@ -107,7 +107,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
s := &Source{
Config: r,
Client: &client,
client: &client,
}
return s, nil
@@ -117,7 +117,7 @@ var _ sources.Source = &Source{}
type Source struct {
Config
Client *http.Client
client *http.Client
}
func (s *Source) SourceKind() string {
@@ -127,3 +127,19 @@ func (s *Source) SourceKind() string {
func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) HttpDefaultHeaders() map[string]string {
return s.DefaultHeaders
}
func (s *Source) HttpBaseURL() string {
return s.BaseURL
}
func (s *Source) HttpQueryParams() map[string]string {
return s.QueryParams
}
func (s *Source) Client() *http.Client {
return s.client
}

View File

@@ -160,10 +160,6 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetApiSettings() *rtl.ApiSettings {
return s.ApiSettings
}
func (s *Source) UseClientAuthorization() bool {
return strings.ToLower(s.UseClientOAuth) != "false"
}
@@ -188,6 +184,30 @@ func (s *Source) GoogleCloudTokenSourceWithScope(ctx context.Context, scope stri
return google.DefaultTokenSource(ctx, scope)
}
func (s *Source) LookerClient() *v4.LookerSDK {
return s.Client
}
func (s *Source) LookerApiSettings() *rtl.ApiSettings {
return s.ApiSettings
}
func (s *Source) LookerShowHiddenFields() bool {
return s.ShowHiddenFields
}
func (s *Source) LookerShowHiddenModels() bool {
return s.ShowHiddenModels
}
func (s *Source) LookerShowHiddenExplores() bool {
return s.ShowHiddenExplores
}
func (s *Source) LookerSessionLength() int64 {
return s.SessionLength
}
func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) {
cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...)
if err != nil {

View File

@@ -96,6 +96,14 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetProject() string {
return s.Project
}
func (s *Source) GetLocation() string {
return s.Location
}
func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient {
return s.Client
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the create-cluster tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -97,7 +102,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -107,7 +111,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-cluster tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
@@ -120,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
if !ok || project == "" {
@@ -151,7 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -198,10 +206,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the create-instance tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -108,7 +112,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-instance tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
@@ -121,6 +124,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
if !ok || project == "" {
@@ -147,7 +155,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -208,10 +216,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
@@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the create-user tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -98,7 +103,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -108,9 +112,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the create-user tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -121,6 +123,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
if !ok || project == "" {
@@ -147,7 +154,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -208,10 +215,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-get-cluster"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the get-cluster tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -104,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the get-cluster tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters
manifest tools.Manifest
@@ -117,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -132,7 +141,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -167,10 +176,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-get-instance"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the get-instance tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the get-instance tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters
AllParams parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'instance' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-get-user"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the get-user tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -95,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -105,9 +110,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the get-user tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters
AllParams parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -118,6 +121,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -137,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -172,10 +180,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-list-clusters"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the list-clusters tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -93,7 +99,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -103,9 +108,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the list-clusters tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -116,6 +119,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -127,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -162,10 +170,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-list-instances"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the list-instances tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the list-instances tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,9 +20,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-list-users"
@@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Configuration for the list-users tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -66,12 +72,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -94,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -104,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the list-users tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -117,6 +120,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -132,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -167,10 +175,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -25,9 +25,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-wait-for-operation"
@@ -89,6 +89,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
}
// Config defines the configuration for the wait-for-operation tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
@@ -119,12 +125,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(*alloydbadmin.Source)
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.DefaultProject
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
@@ -180,7 +186,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
Source: s,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -194,19 +199,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Tool represents the wait-for-operation tool.
type Tool struct {
Config
Source *alloydbadmin.Source
AllParams parameters.Parameters `yaml:"allParams"`
AllParams parameters.Parameters `yaml:"allParams"`
Client *http.Client
manifest tools.Manifest
mcpManifest tools.McpManifest
// Polling configuration
Delay time.Duration
MaxDelay time.Duration
Multiplier float64
MaxRetries int
Client *http.Client
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -215,6 +217,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -230,7 +237,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("missing 'operation' parameter")
}
service, err := t.Source.GetService(ctx, string(accessToken))
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
@@ -363,10 +370,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"github.com/jackc/pgx/v5/pgxpool"
@@ -47,11 +46,6 @@ type compatibleSource interface {
PostgresPool() *pgxpool.Pool
}
// validate compatible sources are still compatible
var _ compatibleSource = &alloydbpg.Source{}
var compatibleSources = [...]string{alloydbpg.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
numParams := len(cfg.NLConfigParameters)
quotedNameParts := make([]string, 0, numParams)
placeholderParts := make([]string, 0, numParams)
@@ -126,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Config: cfg,
Parameters: cfg.NLConfigParameters,
Statement: stmt,
Pool: s.PostgresPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -139,9 +120,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Pool *pgxpool.Pool
Parameters parameters.Parameters `yaml:"parameters"`
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
@@ -152,6 +131,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
pool := source.PostgresPool()
sliceParams := params.AsSlice()
allParamValues := make([]any, len(sliceParams)+1)
allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question
@@ -160,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
allParamValues[i+2] = fmt.Sprintf("%s", param)
}
results, err := t.Pool.Query(ctx, t.Statement, allParamValues...)
results, err := pool.Query(ctx, t.Statement, allParamValues...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues)
}
@@ -203,10 +188,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -57,11 +57,6 @@ type compatibleSource interface {
BigQuerySession() bigqueryds.BigQuerySessionProvider
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
allowedDatasets := s.BigQueryAllowedDatasets()
@@ -136,17 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
SessionProvider: s.BigQuerySession(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -156,17 +144,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
SessionProvider bigqueryds.BigQuerySessionProvider
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -175,23 +155,27 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke runs the contribution analysis.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
inputData, ok := paramsMap["input_data"].(string)
if !ok {
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
}
bqClient := t.Client
restService := t.RestService
var err error
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -229,9 +213,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
var inputDataSource string
trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
var connProps []*bigqueryapi.ConnectionProperty
session, err := t.SessionProvider(ctx)
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
{Key: "session_id", Value: session.ID},
}
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps)
if err != nil {
return nil, fmt.Errorf("query validation failed: %w", err)
}
@@ -252,7 +236,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
queryStats := dryRunJob.Statistics.Query
if queryStats != nil {
for _, tableRef := range queryStats.ReferencedTables {
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
}
}
@@ -262,18 +246,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
inputDataSource = fmt.Sprintf("(%s)", inputData)
} else {
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
parts := strings.Split(inputData, ".")
var projectID, datasetID string
switch len(parts) {
case 3: // project.dataset.table
projectID, datasetID = parts[0], parts[1]
case 2: // dataset.table
projectID, datasetID = t.Client.Project(), parts[0]
projectID, datasetID = source.BigQueryClient().Project(), parts[0]
default:
return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData)
}
if !t.IsDatasetAllowed(projectID, datasetID) {
if !source.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData)
}
}
@@ -292,7 +276,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// Get session from provider if in protected mode.
// Otherwise, a new session will be created by the first query.
session, err := t.SessionProvider(ctx)
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -385,10 +369,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -26,7 +26,6 @@ import (
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -105,11 +104,6 @@ type CAPayload struct {
ClientIdEnum string `json:"clientIdEnum"`
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -135,7 +129,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
allowedDatasets := s.BigQueryAllowedDatasets()
@@ -153,31 +147,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
params := parameters.Parameters{userQueryParameter, tableRefsParameter}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// Get cloud-platform token source for Gemini Data Analytics API during initialization
var bigQueryTokenSourceWithScope oauth2.TokenSource
if !s.UseClientAuthorization() {
ctx := context.Background()
ts, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
}
bigQueryTokenSourceWithScope = ts
}
// finish tool setup
t := Tool{
Config: cfg,
Project: s.BigQueryProject(),
Location: s.BigQueryLocation(),
Parameters: params,
Client: s.BigQueryClient(),
UseClientOAuth: s.UseClientAuthorization(),
TokenSource: bigQueryTokenSourceWithScope,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
MaxQueryResultRows: s.GetMaxQueryResultRows(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -187,18 +162,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project string
Location string
Client *bigqueryapi.Client
TokenSource oauth2.TokenSource
manifest tools.Manifest
mcpManifest tools.McpManifest
MaxQueryResultRows int
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -206,11 +172,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
var tokenStr string
var err error
// Get credentials for the API call
if t.UseClientOAuth {
if source.UseClientAuthorization() {
// Use client-side access token
if accessToken == "" {
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
@@ -220,11 +190,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} else {
// Get cloud-platform token source for Gemini Data Analytics API during initialization
tokenSource, err := source.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
}
// Use cloud-platform token source for Gemini Data Analytics API
if t.TokenSource == nil {
if tokenSource == nil {
return nil, fmt.Errorf("cloud-platform token source is missing")
}
token, err := t.TokenSource.Token()
token, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err)
}
@@ -245,17 +221,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
for _, tableRef := range tableRefs {
if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
}
}
}
// Construct URL, headers, and payload
projectID := t.Project
location := t.Location
projectID := source.BigQueryProject()
location := source.BigQueryLocation()
if location == "" {
location = "us"
}
@@ -279,7 +255,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
// Call the streaming API
response, err := getStream(caURL, payload, headers, t.MaxQueryResultRows)
response, err := getStream(caURL, payload, headers, source.GetMaxQueryResultRows())
if err != nil {
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
}
@@ -303,8 +279,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
// StreamMessage represents a single message object from the streaming API response.
@@ -580,6 +560,6 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s
return append(messages, newMessage)
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -60,11 +60,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -90,7 +85,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
var sqlDescriptionBuilder strings.Builder
@@ -136,18 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
WriteMode: s.BigQueryWriteMode(),
SessionProvider: s.BigQuerySession(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -157,18 +144,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
WriteMode string
SessionProvider bigqueryds.BigQuerySessionProvider
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -176,6 +154,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {
@@ -186,17 +169,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
}
bqClient := t.Client
restService := t.RestService
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -204,8 +186,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
var connProps []*bigqueryapi.ConnectionProperty
var session *bigqueryds.Session
if t.WriteMode == bigqueryds.WriteModeProtected {
session, err = t.SessionProvider(ctx)
if source.BigQueryWriteMode() == bigqueryds.WriteModeProtected {
session, err = source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
}
@@ -221,7 +203,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
statementType := dryRunJob.Statistics.Query.StatementType
switch t.WriteMode {
switch source.BigQueryWriteMode() {
case bigqueryds.WriteModeBlocked:
if statementType != "SELECT" {
return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed")
@@ -235,7 +217,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
switch statementType {
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
@@ -270,7 +252,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
} else if statementType != "SELECT" {
// If dry run yields no tables, fall back to the parser for non-SELECT statements
// to catch unsafe operations like EXECUTE IMMEDIATE.
parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project())
parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project())
if parseErr != nil {
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
@@ -282,7 +264,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
parts := strings.Split(tableID, ".")
if len(parts) == 3 {
projectID, datasetID := parts[0], parts[1]
if !t.IsDatasetAllowed(projectID, datasetID) {
if !source.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
}
}
@@ -374,10 +356,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -57,11 +57,6 @@ type compatibleSource interface {
BigQuerySession() bigqueryds.BigQuerySessionProvider
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
allowedDatasets := s.BigQueryAllowedDatasets()
@@ -116,17 +111,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
SessionProvider: s.BigQuerySession(),
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -136,17 +124,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
SessionProvider bigqueryds.BigQuerySessionProvider
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -154,6 +134,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
historyData, ok := paramsMap["history_data"].(string)
if !ok {
@@ -188,17 +173,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
bqClient := t.Client
restService := t.RestService
var err error
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, false)
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -207,9 +191,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
var historyDataSource string
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
var connProps []*bigqueryapi.ConnectionProperty
session, err := t.SessionProvider(ctx)
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -218,7 +202,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
{Key: "session_id", Value: session.ID},
}
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps)
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps)
if err != nil {
return nil, fmt.Errorf("query validation failed: %w", err)
}
@@ -230,7 +214,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
queryStats := dryRunJob.Statistics.Query
if queryStats != nil {
for _, tableRef := range queryStats.ReferencedTables {
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
}
}
@@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
historyDataSource = fmt.Sprintf("(%s)", historyData)
} else {
if len(t.AllowedDatasets) > 0 {
if len(source.BigQueryAllowedDatasets()) > 0 {
parts := strings.Split(historyData, ".")
var projectID, datasetID string
@@ -249,13 +233,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
projectID = parts[0]
datasetID = parts[1]
case 2: // dataset.table
projectID = t.Client.Project()
projectID = source.BigQueryClient().Project()
datasetID = parts[0]
default:
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
}
if !t.IsDatasetAllowed(projectID, datasetID) {
if !source.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
}
}
@@ -279,7 +263,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// JobStatistics.QueryStatistics.StatementType
query := bqClient.Query(sql)
query.Location = bqClient.Location
session, err := t.SessionProvider(ctx)
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -349,10 +333,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -54,11 +54,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -84,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
defaultProjectID := s.BigQueryProject()
@@ -104,14 +99,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -121,15 +112,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
IsDatasetAllowed func(projectID, datasetID string) bool
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -137,6 +122,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {
@@ -148,22 +138,21 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
}
bqClient := t.Client
var err error
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
if !t.IsDatasetAllowed(projectId, datasetId) {
if !source.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
@@ -193,10 +182,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -55,11 +55,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
defaultProjectID := s.BigQueryProject()
@@ -108,14 +103,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -125,15 +116,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
IsDatasetAllowed func(projectID, datasetID string) bool
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -141,6 +126,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {
@@ -157,20 +147,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
}
if !t.IsDatasetAllowed(projectId, datasetId) {
if !source.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
bqClient := t.Client
bqClient := source.BigQueryClient()
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -203,10 +192,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -52,11 +52,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -82,7 +77,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
var projectParameter parameters.Parameter
@@ -103,14 +98,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -120,15 +111,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
AllowedDatasets []string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -136,8 +121,13 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
if len(t.AllowedDatasets) > 0 {
return t.AllowedDatasets, nil
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
if len(source.BigQueryAllowedDatasets()) > 0 {
return source.BigQueryAllowedDatasets(), nil
}
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
@@ -145,14 +135,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
}
bqClient := t.Client
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -197,10 +187,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -55,11 +55,6 @@ type compatibleSource interface {
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
defaultProjectID := s.BigQueryProject()
@@ -107,14 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -124,15 +115,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -140,6 +125,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {
@@ -151,18 +141,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
}
if !t.IsDatasetAllowed(projectId, datasetId) {
if !source.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
bqClient := t.Client
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -208,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -51,11 +51,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -72,20 +67,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// Initialize the search configuration with the provided sources
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
// Get the Dataplex client using the method from the source
makeCatalogClient := s.MakeDataplexCatalogClient()
prompt := parameters.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.")
datasetIds := parameters.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", parameters.NewStringParameter("datasetId", "The IDs of the bigquery dataset."))
projectIds := parameters.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", parameters.NewStringParameter("projectId", "The IDs of the bigquery project."))
@@ -100,11 +81,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, params, nil)
t := Tool{
Config: cfg,
Parameters: params,
UseClientOAuth: s.UseClientAuthorization(),
MakeCatalogClient: makeCatalogClient,
ProjectID: s.BigQueryProject(),
Config: cfg,
Parameters: params,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: params.Manifest(),
@@ -117,12 +95,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
type Tool struct {
Config
Parameters parameters.Parameters
UseClientOAuth bool
MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error)
ProjectID string
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -133,8 +108,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func constructSearchQueryHelper(predicate string, operator string, items []string) string {
@@ -207,6 +186,11 @@ func ExtractType(resourceString string) string {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
pageSize := int32(paramsMap["pageSize"].(int))
prompt, _ := paramsMap["prompt"].(string)
@@ -228,14 +212,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
req := &dataplexpb.SearchEntriesRequest{
Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)),
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
Name: fmt.Sprintf("projects/%s/locations/global", source.BigQueryProject()),
PageSize: pageSize,
SemanticSearch: true,
}
catalogClient, dataplexClientCreator, _ := t.MakeCatalogClient()
catalogClient, dataplexClientCreator, _ := source.MakeDataplexCatalogClient()()
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
@@ -248,7 +232,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
it := catalogClient.SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.BigQueryProject())
}
var results []Response
@@ -288,6 +272,6 @@ func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -57,11 +57,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -81,18 +76,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
if err != nil {
return nil, err
@@ -102,15 +85,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
AllParams: allParameters,
UseClientOAuth: s.UseClientAuthorization(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
SessionProvider: s.BigQuerySession(),
ClientCreator: s.BigQueryClientCreator(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -120,15 +98,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
AllParams parameters.Parameters `yaml:"allParams"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
SessionProvider bigqueryds.BigQuerySessionProvider
ClientCreator bigqueryds.BigqueryClientCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -136,6 +108,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
@@ -212,16 +189,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
lowLevelParams = append(lowLevelParams, lowLevelParam)
}
bqClient := t.Client
restService := t.RestService
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
@@ -232,8 +209,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
query.Location = bqClient.Location
connProps := []*bigqueryapi.ConnectionProperty{}
if t.SessionProvider != nil {
session, err := t.SessionProvider(ctx)
if source.BigQuerySession() != nil {
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
@@ -311,10 +288,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
"cloud.google.com/go/bigtable"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigtabledb "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -46,11 +45,6 @@ type compatibleSource interface {
BigtableClient() *bigtable.Client
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigtabledb.Source{}
var compatibleSources = [...]string{bigtabledb.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
if err != nil {
return nil, err
@@ -93,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
AllParams: allParameters,
Client: s.BigtableClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -105,9 +86,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Client *bigtable.Client
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -156,6 +135,11 @@ func getMapParamsType(tparams parameters.Parameters, params parameters.ParamValu
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
@@ -172,7 +156,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("fail to get map params: %w", err)
}
ps, err := t.Client.PrepareStatement(
ps, err := source.BigtableClient().PrepareStatement(
ctx,
newStatement,
mapParamsType,
@@ -224,10 +208,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -21,7 +21,6 @@ import (
gocql "github.com/apache/cassandra-gocql-driver/v2"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -46,10 +45,6 @@ type compatibleSource interface {
CassandraSession() *gocql.Session
}
var _ compatibleSource = &cassandra.Source{}
var compatibleSources = [...]string{cassandra.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -61,20 +56,15 @@ type Config struct {
TemplateParameters parameters.Parameters `yaml:"templateParameters"`
}
var _ tools.ToolConfig = Config{}
// ToolConfigKind implements tools.ToolConfig.
func (c Config) ToolConfigKind() string {
return kind
}
// Initialize implements tools.ToolConfig.
func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[c.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", c.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, err := parameters.ProcessParameters(c.TemplateParameters, c.Parameters)
if err != nil {
return nil, err
@@ -85,25 +75,17 @@ func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
t := Tool{
Config: c,
AllParams: allParameters,
Session: s.CassandraSession(),
manifest: tools.Manifest{Description: c.Description, Parameters: paramManifest, AuthRequired: c.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
// ToolConfigKind implements tools.ToolConfig.
func (c Config) ToolConfigKind() string {
return kind
}
var _ tools.ToolConfig = Config{}
var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Session *gocql.Session
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -113,8 +95,8 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
// RequiresClientAuthorization implements tools.Tool.
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
// Authorized implements tools.Tool.
@@ -124,6 +106,11 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
// Invoke implements tools.Tool.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
@@ -135,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to extract standard params %w", err)
}
sliceParams := newParams.AsSlice()
iter := t.Session.Query(newStatement, sliceParams...).IterContext(ctx)
iter := source.CassandraSession().Query(newStatement, sliceParams...).IterContext(ctx)
// Create a slice to store the out
var out []map[string]interface{}
@@ -170,8 +157,6 @@ func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any)
return parameters.ParseParams(t.AllParams, data, claims)
}
var _ tools.Tool = Tool{}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -25,12 +25,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
type compatibleSource interface {
ClickHousePool() *sql.DB
}
var compatibleSources = []string{"clickhouse"}
const executeSQLKind string = "clickhouse-execute-sql"
func init() {
@@ -47,6 +41,10 @@ func newExecuteSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder
return actual, nil
}
type compatibleSource interface {
ClickHousePool() *sql.DB
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -62,16 +60,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", executeSQLKind, compatibleSources)
}
sqlParameter := parameters.NewStringParameter("sql", "The SQL statement to execute.")
params := parameters.Parameters{sqlParameter}
@@ -80,7 +68,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
Parameters: params,
Pool: s.ClickHousePool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -91,9 +78,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters parameters.Parameters `yaml:"parameters"`
Pool *sql.DB
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -103,13 +88,18 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {
return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"])
}
results, err := t.Pool.QueryContext(ctx, sql)
results, err := source.ClickHousePool().QueryContext(ctx, sql)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -183,10 +173,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -25,12 +25,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
type compatibleSource interface {
ClickHousePool() *sql.DB
}
var compatibleSources = []string{"clickhouse"}
const listDatabasesKind string = "clickhouse-list-databases"
func init() {
@@ -47,6 +41,10 @@ func newListDatabasesConfig(ctx context.Context, name string, decoder *yaml.Deco
return actual, nil
}
type compatibleSource interface {
ClickHousePool() *sql.DB
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -63,23 +61,12 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listDatabasesKind, compatibleSources)
}
allParameters, paramManifest, _ := parameters.ProcessParameters(nil, cfg.Parameters)
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
t := Tool{
Config: cfg,
AllParams: allParameters,
Pool: s.ClickHousePool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -90,9 +77,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Pool *sql.DB
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -102,10 +87,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
// Query to list all databases
query := "SHOW DATABASES"
results, err := t.Pool.QueryContext(ctx, query)
results, err := source.ClickHousePool().QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -146,10 +136,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -32,21 +31,6 @@ func TestListDatabasesConfigToolConfigKind(t *testing.T) {
}
}
func TestListDatabasesConfigInitializeMissingSource(t *testing.T) {
cfg := Config{
Name: "test-list-databases",
Kind: listDatabasesKind,
Source: "missing-source",
Description: "Test list databases tool",
}
srcs := map[string]sources.Source{}
_, err := cfg.Initialize(srcs)
if err == nil {
t.Error("expected error for missing source")
}
}
func TestParseFromYamlClickHouseListDatabases(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {

View File

@@ -25,12 +25,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
type compatibleSource interface {
ClickHousePool() *sql.DB
}
var compatibleSources = []string{"clickhouse"}
const listTablesKind string = "clickhouse-list-tables"
const databaseKey string = "database"
@@ -48,6 +42,10 @@ func newListTablesConfig(ctx context.Context, name string, decoder *yaml.Decoder
return actual, nil
}
type compatibleSource interface {
ClickHousePool() *sql.DB
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -64,16 +62,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listTablesKind, compatibleSources)
}
databaseParameter := parameters.NewStringParameter(databaseKey, "The database to list tables from.")
params := parameters.Parameters{databaseParameter}
@@ -83,7 +71,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
t := Tool{
Config: cfg,
AllParams: allParameters,
Pool: s.ClickHousePool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -94,9 +81,7 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Pool *sql.DB
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -106,6 +91,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
mapParams := params.AsMap()
database, ok := mapParams[databaseKey].(string)
if !ok {
@@ -115,7 +105,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// Query to list all tables in the specified database
query := fmt.Sprintf("SHOW TABLES FROM %s", database)
results, err := t.Pool.QueryContext(ctx, query)
results, err := source.ClickHousePool().QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -157,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -20,7 +20,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -32,21 +31,6 @@ func TestListTablesConfigToolConfigKind(t *testing.T) {
}
}
func TestListTablesConfigInitializeMissingSource(t *testing.T) {
cfg := Config{
Name: "test-list-tables",
Kind: listTablesKind,
Source: "missing-source",
Description: "Test list tables tool",
}
srcs := map[string]sources.Source{}
_, err := cfg.Initialize(srcs)
if err == nil {
t.Error("expected error for missing source")
}
}
func TestParseFromYamlClickHouseListTables(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {

View File

@@ -25,21 +25,15 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
type compatibleSource interface {
ClickHousePool() *sql.DB
}
var compatibleSources = []string{"clickhouse"}
const sqlKind string = "clickhouse-sql"
func init() {
if !tools.Register(sqlKind, newSQLConfig) {
if !tools.Register(sqlKind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", sqlKind))
}
}
func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
@@ -47,6 +41,10 @@ func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tool
return actual, nil
}
type compatibleSource interface {
ClickHousePool() *sql.DB
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -65,23 +63,12 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", sqlKind, compatibleSources)
}
allParameters, paramManifest, _ := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
t := Tool{
Config: cfg,
AllParams: allParameters,
Pool: s.ClickHousePool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -93,7 +80,6 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
Pool *sql.DB
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -103,6 +89,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
@@ -115,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
sliceParams := newParams.AsSlice()
results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...)
results, err := source.ClickHousePool().QueryContext(ctx, newStatement, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -191,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return false
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -142,66 +142,6 @@ func TestSQLConfigInitializeValidSource(t *testing.T) {
}
}
func TestSQLConfigInitializeMissingSource(t *testing.T) {
config := Config{
Name: "test-tool",
Kind: sqlKind,
Source: "missing-source",
Description: "Test tool",
Statement: "SELECT 1",
Parameters: parameters.Parameters{},
}
sources := map[string]sources.Source{}
_, err := config.Initialize(sources)
if err == nil {
t.Fatal("Expected error for missing source, got nil")
}
expectedErr := `no source named "missing-source" configured`
if err.Error() != expectedErr {
t.Errorf("Expected error %q, got %q", expectedErr, err.Error())
}
}
// mockIncompatibleSource is a mock source that doesn't implement the compatibleSource interface
type mockIncompatibleSource struct{}
func (m *mockIncompatibleSource) SourceKind() string {
return "mock"
}
func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig {
return nil
}
func TestSQLConfigInitializeIncompatibleSource(t *testing.T) {
config := Config{
Name: "test-tool",
Kind: sqlKind,
Source: "incompatible-source",
Description: "Test tool",
Statement: "SELECT 1",
Parameters: parameters.Parameters{},
}
mockSource := &mockIncompatibleSource{}
sources := map[string]sources.Source{
"incompatible-source": mockSource,
}
_, err := config.Initialize(sources)
if err == nil {
t.Fatal("Expected error for incompatible source, got nil")
}
if err.Error() == "" {
t.Error("Expected non-empty error message")
}
}
func TestToolManifest(t *testing.T) {
tool := Tool{
manifest: tools.Manifest{

View File

@@ -24,7 +24,6 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
type compatibleSource interface {
GetProjectID() string
GetBaseURL() string
UseClientAuthorization() bool
GetClient(context.Context, string) (*http.Client, error)
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -64,18 +70,6 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(*cloudgdasrc.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-gemini-data-analytics`", kind)
}
// Define the parameters for the Gemini Data Analytics Query API
// The prompt is the only input parameter.
allParameters := parameters.Parameters{
@@ -87,7 +81,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return Tool{
Config: cfg,
AllParams: allParameters,
Source: s,
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
@@ -99,7 +92,6 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters
Source *cloudgdasrc.Source
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -110,6 +102,11 @@ func (t Tool) ToConfig() tools.ToolConfig {
// Invoke executes the tool logic
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
prompt, ok := paramsMap["prompt"].(string)
if !ok {
@@ -118,11 +115,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// The API endpoint itself always uses the "global" location.
apiLocation := "global"
apiParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", t.Source.BaseURL, apiParent)
apiParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", source.GetBaseURL(), apiParent)
// The parent in the request payload uses the tool's configured location.
payloadParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, t.Location)
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
payload := &QueryDataRequest{
Parent: payloadParent,
@@ -138,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// Parse the access token if provided
var tokenStr string
if t.RequiresClientAuthorization(resourceMgr) {
if source.UseClientAuthorization() {
var err error
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
@@ -146,7 +143,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
client, err := t.Source.GetClient(ctx, tokenStr)
client, err := source.GetClient(ctx, tokenStr)
if err != nil {
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
}
@@ -196,10 +193,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -26,6 +26,7 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/testutils"
@@ -172,9 +173,8 @@ func TestInitialize(t *testing.T) {
}
tcs := []struct {
desc string
cfg cloudgdatool.Config
expectErr bool
desc string
cfg cloudgdatool.Config
}{
{
desc: "successful initialization",
@@ -185,29 +185,6 @@ func TestInitialize(t *testing.T) {
Description: "Test Description",
Location: "us-central1",
},
expectErr: false,
},
{
desc: "missing source",
cfg: cloudgdatool.Config{
Name: "my-gda-query-tool",
Kind: "cloud-gemini-data-analytics-query",
Source: "non-existent-source",
Description: "Test Description",
Location: "us-central1",
},
expectErr: true,
},
{
desc: "incompatible source kind",
cfg: cloudgdatool.Config{
Name: "my-gda-query-tool",
Kind: "cloud-gemini-data-analytics-query",
Source: "incompatible-source",
Description: "Test Description",
Location: "us-central1",
},
expectErr: true,
},
}
@@ -219,16 +196,11 @@ func TestInitialize(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
tool, err := tc.cfg.Initialize(srcs)
if tc.expectErr && err == nil {
t.Fatalf("expected an error but got none")
}
if !tc.expectErr && err != nil {
if err != nil {
t.Fatalf("did not expect an error but got: %v", err)
}
if !tc.expectErr {
// Basic sanity check on the returned tool
_ = tool // Avoid unused variable error
}
// Basic sanity check on the returned tool
_ = tool // Avoid unused variable error
})
}
}
@@ -361,8 +333,10 @@ func TestInvoke(t *testing.T) {
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
}
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil)
// Invoke the tool
result, err := tool.Invoke(ctx, nil, params, "") // No accessToken needed for ADC client
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
if err != nil {
t.Fatalf("tool invocation failed: %v", err)
}

View File

@@ -62,11 +62,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,35 +78,16 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
urlParameter := parameters.NewStringParameter(pageURLKey, "The full URL of the FHIR page to fetch. This would be the value of `Bundle.entry.link.url` field within the response returned from FHIR search or FHIR patient everything operations.")
params := parameters.Parameters{urlParameter}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -121,14 +97,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -136,13 +107,18 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
url, ok := params.AsMap()[pageURLKey].(string)
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
}
var httpClient *http.Client
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
@@ -150,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
httpClient = oauth2.NewClient(ctx, ts)
} else {
// The t.Service object holds a client with the default credentials.
// The source.Service() object holds a client with the default credentials.
// However, the client is not exported, so we have to create a new one.
var err error
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
@@ -201,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -62,11 +62,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -92,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
idParameter := parameters.NewStringParameter(patientIDKey, "The ID of the patient FHIR resource for which the information is required")
@@ -106,17 +101,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -126,15 +114,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -142,7 +124,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
@@ -151,20 +138,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
}
svc := t.Service
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", t.Project, t.Region, t.Dataset, storeID, patientID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID)
var opts []googleapi.CallOption
if val, ok := params.AsMap()[typeFilterKey]; ok {
types, ok := val.([]any)
@@ -225,10 +212,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -78,11 +78,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -108,7 +103,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{
@@ -140,17 +135,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -160,15 +148,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -176,19 +158,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
@@ -261,7 +248,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
opts = append(opts, googleapi.QueryParameter("_summary", "text"))
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search patient resources: %w", err)
@@ -298,10 +285,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -51,11 +51,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -72,33 +67,15 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
params := parameters.Parameters{}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -108,13 +85,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -122,22 +95,26 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
svc := t.Service
var err error
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
@@ -161,10 +138,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{}
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{}
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -59,11 +59,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -89,7 +84,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
typeParameter := parameters.NewStringParameter(typeKey, "The FHIR resource type to retrieve (e.g., Patient, Observation).")
@@ -102,17 +97,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -122,15 +110,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -138,7 +120,12 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
@@ -152,20 +139,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
}
svc := t.Service
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", t.Project, t.Region, t.Dataset, storeID, resType, resID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", source.Project(), source.Region(), source.DatasetID(), storeID, resType, resID)
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
resp, err := call.Do()
@@ -204,10 +191,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{}
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{}
@@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -114,15 +102,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -130,25 +112,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
@@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
params := parameters.Parameters{}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -111,15 +87,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -127,29 +97,33 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
svc := t.Service
var err error
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
var filtered []*healthcare.DicomStore
for _, store := range stores.DicomStores {
if len(t.AllowedStores) == 0 {
if len(source.AllowedDICOMStores()) == 0 {
filtered = append(filtered, store)
continue
}
@@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok {
if _, ok := source.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
@@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -53,11 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string {
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
params := parameters.Parameters{}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedFHIRStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -111,15 +87,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -127,29 +97,33 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
svc := t.Service
var err error
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset)
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
var filtered []*healthcare.FhirStore
for _, store := range stores.FhirStores {
if len(t.AllowedStores) == 0 {
if len(source.AllowedFHIRStores()) == 0 {
filtered = append(filtered, store)
continue
}
@@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok {
if _, ok := source.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
@@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -61,11 +61,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -91,7 +86,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{
@@ -107,17 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -127,15 +115,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -143,19 +125,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
@@ -177,7 +164,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok {
return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey)
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
call.Header().Set("Accept", "image/jpeg")
@@ -214,10 +201,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -68,11 +68,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -98,7 +93,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{
@@ -121,17 +116,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -141,15 +129,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -157,19 +139,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
@@ -204,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search dicom instances: %w", err)
@@ -244,10 +231,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -65,11 +65,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
var _ compatibleSource = &healthcareds.Source{}
var compatibleSources = [...]string{healthcareds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
@@ -95,7 +90,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
params := parameters.Parameters{
@@ -117,17 +112,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: params,
Project: s.Project(),
Region: s.Region(),
Dataset: s.DatasetID(),
AllowedStores: s.AllowedDICOMStores(),
UseClientOAuth: s.UseClientAuthorization(),
ServiceCreator: s.ServiceCreator(),
Service: s.Service(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Config: cfg,
Parameters: params,
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -137,15 +125,9 @@ var _ tools.Tool = Tool{}
type Tool struct {
Config
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters parameters.Parameters `yaml:"parameters"`
Project, Region, Dataset string
AllowedStores map[string]struct{}
Service *healthcare.Service
ServiceCreator healthcareds.HealthcareServiceCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters `yaml:"parameters"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
@@ -153,19 +135,24 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
svc := t.Service
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if t.UseClientOAuth {
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = t.ServiceCreator(tokenStr)
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
@@ -187,7 +174,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search dicom series: %w", err)
@@ -227,10 +214,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.UseClientOAuth
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

Some files were not shown because too many files have changed in this diff Show More