Compare commits

..

23 Commits

Author SHA1 Message Date
rahulpinto19
c835af431a test 2026-01-12 06:14:32 +00:00
manuka rahul
bde46e1401 Merge branch 'main' into integration-test-cleanup-testing 2026-01-12 05:07:13 +00:00
Shobhit Singh
49d255254c feat(bigquery): make maximum rows returned from queries configurable (#2262)
This change allows the agent developer to control the maxium number of
rows returned from tools running BigQuery SQL query. Using this feature
the agent developer could limit how large output is presented to LLM in
an agentic user journey.

## Description

> Should include a concise description of the changes (bug or feature),
it's
> impact, along with a summary of the solution

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue
https://github.com/googleapis/genai-toolbox/issues/2261
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #2261 2261
2026-01-12 03:22:13 +00:00
Yuan Teoh
eedd0ab67f docs: add issue and pr triaging and SLO (#2257)
## Description

update docs to reflect triaging workflow and SLO

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #<issue_number_goes_here>
2026-01-12 03:22:12 +00:00
rahulpinto19
65b21960fa test 2026-01-09 08:42:14 +00:00
rahulpinto19
6ff567df47 test 2026-01-09 07:03:44 +00:00
rahulpinto19
cc3c52731d test 2026-01-09 06:53:10 +00:00
rahulpinto19
87ff9150d2 printign in common.go 2026-01-09 06:45:27 +00:00
manuka rahul
81b42ac973 Merge branch 'main' into integration-test-cleanup-testing 2026-01-09 06:36:02 +00:00
release-please[bot]
be7560f606 chore(main): release 0.25.0 (#2218)
🤖 I have created a release *beep* *boop*
---


##
[0.25.0](https://github.com/googleapis/genai-toolbox/compare/v0.24.0...v0.25.0)
(2026-01-08)


### Features

* Add `embeddingModel` support
([#2121](https://github.com/googleapis/genai-toolbox/issues/2121))
([9c62f31](9c62f313ff))
* Add `allowed-hosts` flag
([#2254](https://github.com/googleapis/genai-toolbox/issues/2254))
([17b41f6](17b41f6453))
* Add parameter default value to manifest
([#2264](https://github.com/googleapis/genai-toolbox/issues/2264))
([9d1feca](9d1feca108))
* **snowflake:** Add Snowflake Source and Tools
([#858](https://github.com/googleapis/genai-toolbox/issues/858))
([b706b5b](b706b5bc68))
* **prebuilt/cloud-sql-mysql:** Update CSQL MySQL prebuilt tools to use
IAM ([#2202](https://github.com/googleapis/genai-toolbox/issues/2202))
([731a32e](731a32e536))
* **sources/bigquery:** Make credentials scope configurable
([#2210](https://github.com/googleapis/genai-toolbox/issues/2210))
([a450600](a4506009b9))
* **sources/trino:** Add ssl verification options and fix docs example
([#2155](https://github.com/googleapis/genai-toolbox/issues/2155))
([4a4cf1e](4a4cf1e712))
* **tools/looker:** Add ability to set destination folder with
`make_look` and `make_dashboard`.
([#2245](https://github.com/googleapis/genai-toolbox/issues/2245))
([eb79339](eb793398cd))
* **tools/postgressql:** Add tool to list store procedure
([#2156](https://github.com/googleapis/genai-toolbox/issues/2156))
([cf0fc51](cf0fc515b5))
* **tools/postgressql:** Add Parameter `embeddedBy` config support
([#2151](https://github.com/googleapis/genai-toolbox/issues/2151))
([17b70cc](17b70ccaa7))


### Bug Fixes

* **server:** Add `embeddingModel` config initialization
([#2281](https://github.com/googleapis/genai-toolbox/issues/2281))
([a779975](a7799757c9))
* **sources/cloudgda:** Add import for cloudgda source
([#2217](https://github.com/googleapis/genai-toolbox/issues/2217))
([7daa411](7daa4111f4))
* **tools/alloydb-wait-for-operation:** Fix connection message
generation
([#2228](https://github.com/googleapis/genai-toolbox/issues/2228))
([7053fbb](7053fbb195))
* **tools/alloydbainl:** Only add psv when NL Config Param is defined
([#2265](https://github.com/googleapis/genai-toolbox/issues/2265))
([ef8f3b0](ef8f3b02f2))
* **tools/looker:** Looker client OAuth nil pointer error
([#2231](https://github.com/googleapis/genai-toolbox/issues/2231))
([268700b](268700bdbf))

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com>
Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
2026-01-09 05:54:41 +00:00
Yuan Teoh
edaa1c23cf chore: update hugo for release (#2282)
Update hugo version for release v0.25.0
2026-01-09 05:54:41 +00:00
Yuan Teoh
0cfac2984e fix(tools/alloydbainl): only add psv when NL Config Param is defined (#2265)
## Description

PSV should only be required when when it is needed. Currently, we
require psv even whenever user uses AlloyDB AI NL tool. This is due to
the statement that we use to execute nl query.

This PR modified the statement query to only utilize `param_names` and
`param_values` when needed.

Manually tested with a db that does not have psv installed.

🛠️ Fixes #1970
2026-01-09 05:54:41 +00:00
Yuan Teoh
0b7ffa4364 chore: update mcp registry schema version (#2266) 2026-01-09 05:54:41 +00:00
Yuan Teoh
ca4a12b5f3 feat: add default value to manifest (#2264)
## Description

Add default value to manifest (for both native endpoint and mcp
endpoint).

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #1602
2026-01-09 05:54:41 +00:00
Yuan Teoh
7d5af5f001 feat: add allowed-hosts flag (#2254)
## Description

Previously added `allowed-origins` (for CORs) is not sufficient for
preventing DNS rebinding attacks. We'll have to check host headers.

To test, run Toolbox with the following:
```
go run . --allowed-hosts=127.0.0.1:5000
```

Test with the following:
```
// curl successfully
curl -H "Host: 127.0.0.1:5000" http://127.0.0.1:5000

// will show Invalid Host Header error
curl -H "Host: attacker:5000" http://127.0.0.1:5000
```

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #<issue_number_goes_here>
2026-01-09 05:54:41 +00:00
Wenxin Du
2bcfca0981 fix(server): Add embeddingModel config initialization (#2281)
Embedding Models were only loaded in hot reload because it was not
initialized properly.
2026-01-09 05:54:40 +00:00
Twisha Bansal
3932efbd4b docs: link medium blogs to toolbox docsite (#2269)
## Description

Adds a section in the navbar that links to the toolbox medium blog: 
<img width="492" height="822" alt="87F2yTQdcbpMHs3"
src="https://github.com/user-attachments/assets/74d8b552-1e8f-449c-8b09-4f86218d2817"
/>


## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #<issue_number_goes_here>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-09 05:54:40 +00:00
rahulpinto19
1bfe394222 test 2026-01-08 10:07:52 +00:00
rahulpinto19
96904e28e9 test 2026-01-08 10:07:52 +00:00
rahulpinto19
c3c2cec9ab test 2026-01-08 10:07:51 +00:00
rahulpinto19
1a39de7617 test1 2026-01-08 10:07:51 +00:00
rahulpinto19
edf8377c57 test 2026-01-08 10:07:51 +00:00
rahulpinto19
fd4250d0db cleanup updated 2026-01-08 10:07:51 +00:00
47 changed files with 1686 additions and 1377 deletions

View File

@@ -272,7 +272,7 @@ To run Toolbox from binary:
To run the server after pulling the [container image](#installing-the-server):
```sh
export VERSION=0.24.0 # Use the version you pulled
export VERSION=0.11.0 # Use the version you pulled
docker run -p 5000:5000 \
-v $(pwd)/tools.yaml:/app/tools.yaml \
us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION \

View File

@@ -183,13 +183,6 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

View File

@@ -16,12 +16,8 @@ package cloudhealthcare
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -259,299 +255,3 @@ func (s *Source) IsDICOMStoreAllowed(storeID string) bool {
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}
func parseResults(resp *http.Response) (any, error) {
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
var jsonMap map[string]interface{}
if err := json.Unmarshal(respBytes, &jsonMap); err != nil {
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
}
return jsonMap, nil
}
func (s *Source) getService(tokenStr string) (*healthcare.Service, error) {
svc := s.Service()
var err error
// Initialize new service if using user OAuth token
if s.UseClientAuthorization() {
svc, err = s.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
return svc, nil
}
func (s *Source) FHIRFetchPage(ctx context.Context, url, tokenStr string) (any, error) {
var httpClient *http.Client
if s.UseClientAuthorization() {
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
httpClient = oauth2.NewClient(ctx, ts)
} else {
// The source.Service() object holds a client with the default credentials.
// However, the client is not exported, so we have to create a new one.
var err error
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
if err != nil {
return nil, fmt.Errorf("failed to create default http client: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create http request: %w", err)
}
req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
}
defer resp.Body.Close()
return parseResults(resp)
}
func (s *Source) FHIRPatientEverything(storeID, patientID, tokenStr string, opts []googleapi.CallOption) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", s.Project(), s.Region(), s.DatasetID(), storeID, patientID)
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.PatientEverything(name).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to call patient everything for %q: %w", name, err)
}
defer resp.Body.Close()
return parseResults(resp)
}
func (s *Source) FHIRPatientSearch(storeID, tokenStr string, opts []googleapi.CallOption) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search patient resources: %w", err)
}
defer resp.Body.Close()
return parseResults(resp)
}
func (s *Source) GetDataset(tokenStr string) (*healthcare.Dataset, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
return dataset, nil
}
func (s *Source) GetFHIRResource(storeID, resType, resID, tokenStr string) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", s.Project(), s.Region(), s.DatasetID(), storeID, resType, resID)
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
resp, err := call.Do()
if err != nil {
return nil, fmt.Errorf("failed to get fhir resource %q: %w", name, err)
}
defer resp.Body.Close()
return parseResults(resp)
}
func (s *Source) GetDICOMStore(storeID, tokenStr string) (*healthcare.DicomStore, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
}
return store, nil
}
func (s *Source) GetFHIRStore(storeID, tokenStr string) (*healthcare.FhirStore, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
}
return store, nil
}
func (s *Source) GetDICOMStoreMetrics(storeID, tokenStr string) (*healthcare.DicomStoreMetrics, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
}
return store, nil
}
func (s *Source) GetFHIRStoreMetrics(storeID, tokenStr string) (*healthcare.FhirStoreMetrics, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
}
return store, nil
}
func (s *Source) ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
var filtered []*healthcare.DicomStore
for _, store := range stores.DicomStores {
if len(s.AllowedDICOMStores()) == 0 {
filtered = append(filtered, store)
continue
}
if len(store.Name) == 0 {
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := s.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
return filtered, nil
}
func (s *Source) ListFHIRStores(tokenStr string) ([]*healthcare.FhirStore, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
var filtered []*healthcare.FhirStore
for _, store := range stores.FhirStores {
if len(s.AllowedFHIRStores()) == 0 {
filtered = append(filtered, store)
continue
}
if len(store.Name) == 0 {
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := s.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
return filtered, nil
}
func (s *Source) RetrieveRenderedDICOMInstance(storeID, study, series, sop string, frame int, tokenStr string) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
call.Header().Set("Accept", "image/jpeg")
resp, err := call.Do()
if err != nil {
return nil, fmt.Errorf("unable to retrieve dicom instance rendered image: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("RetrieveRendered: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
base64String := base64.StdEncoding.EncodeToString(respBytes)
return base64String, nil
}
func (s *Source) SearchDICOM(toolKind, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
var resp *http.Response
switch toolKind {
case "cloud-healthcare-search-dicom-instances":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
case "cloud-healthcare-search-dicom-series":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
case "cloud-healthcare-search-dicom-studies":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, dicomWebPath).Do(opts...)
default:
return nil, fmt.Errorf("incompatible tool kind: %s", toolKind)
}
if err != nil {
return nil, fmt.Errorf("failed to search dicom series: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
if len(respBytes) == 0 {
return []interface{}{}, nil
}
var result []interface{}
if err := json.Unmarshal(respBytes, &result); err != nil {
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
}
return result, nil
}

View File

@@ -19,8 +19,6 @@ import (
"fmt"
dataplexapi "cloud.google.com/go/dataplex/apiv1"
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
"github.com/cenkalti/backoff/v5"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
@@ -123,101 +121,3 @@ func initDataplexConnection(
}
return client, nil
}
func (s *Source) LookupEntry(ctx context.Context, name string, view int, aspectTypes []string, entry string) (*dataplexpb.Entry, error) {
viewMap := map[int]dataplexpb.EntryView{
1: dataplexpb.EntryView_BASIC,
2: dataplexpb.EntryView_FULL,
3: dataplexpb.EntryView_CUSTOM,
4: dataplexpb.EntryView_ALL,
}
req := &dataplexpb.LookupEntryRequest{
Name: name,
View: viewMap[view],
AspectTypes: aspectTypes,
Entry: entry,
}
result, err := s.CatalogClient().LookupEntry(ctx, req)
if err != nil {
return nil, err
}
return result, nil
}
func (s *Source) searchRequest(ctx context.Context, query string, pageSize int, orderBy string) (*dataplexapi.SearchEntriesResultIterator, error) {
// Create SearchEntriesRequest with the provided parameters
req := &dataplexpb.SearchEntriesRequest{
Query: query,
Name: fmt.Sprintf("projects/%s/locations/global", s.ProjectID()),
PageSize: int32(pageSize),
OrderBy: orderBy,
SemanticSearch: true,
}
// Perform the search using the CatalogClient - this will return an iterator
it := s.CatalogClient().SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", s.ProjectID())
}
return it, nil
}
func (s *Source) SearchAspectTypes(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.AspectType, error) {
q := query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype"
it, err := s.searchRequest(ctx, q, pageSize, orderBy)
if err != nil {
return nil, err
}
// Iterate through the search results and call GetAspectType for each result using the resource name
var results []*dataplexpb.AspectType
for {
entry, err := it.Next()
if err != nil {
break
}
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
getAspectBackOff := backoff.NewExponentialBackOff()
resourceName := entry.DataplexEntry.GetEntrySource().Resource
getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
Name: resourceName,
}
operation := func() (*dataplexpb.AspectType, error) {
aspectType, err := s.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
if err != nil {
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
}
return aspectType, nil
}
// Retry the GetAspectType operation with exponential backoff
aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
if err != nil {
return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
}
results = append(results, aspectType)
}
return results, nil
}
func (s *Source) SearchEntries(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.SearchEntriesResult, error) {
it, err := s.searchRequest(ctx, query, pageSize, orderBy)
if err != nil {
return nil, err
}
var results []*dataplexpb.SearchEntriesResult
for {
entry, err := it.Next()
if err != nil {
break
}
results = append(results, entry)
}
return results, nil
}

View File

@@ -16,10 +16,7 @@ package firestore
import (
"context"
"encoding/base64"
"fmt"
"strings"
"time"
"cloud.google.com/go/firestore"
"github.com/goccy/go-yaml"
@@ -28,7 +25,6 @@ import (
"go.opentelemetry.io/otel/trace"
"google.golang.org/api/firebaserules/v1"
"google.golang.org/api/option"
"google.golang.org/genproto/googleapis/type/latlng"
)
const SourceKind string = "firestore"
@@ -117,476 +113,6 @@ func (s *Source) GetDatabaseId() string {
return s.Database
}
// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
// This removes type information and returns plain values
func FirestoreValueToJSON(value any) any {
if value == nil {
return nil
}
switch v := value.(type) {
case time.Time:
return v.Format(time.RFC3339Nano)
case *latlng.LatLng:
return map[string]any{
"latitude": v.Latitude,
"longitude": v.Longitude,
}
case []byte:
return base64.StdEncoding.EncodeToString(v)
case []any:
result := make([]any, len(v))
for i, item := range v {
result[i] = FirestoreValueToJSON(item)
}
return result
case map[string]any:
result := make(map[string]any)
for k, val := range v {
result[k] = FirestoreValueToJSON(val)
}
return result
case *firestore.DocumentRef:
return v.Path
default:
return value
}
}
// BuildQuery constructs the Firestore query from parameters
func (s *Source) BuildQuery(collectionPath string, filter firestore.EntityFilter, selectFields []string, field string, direction firestore.Direction, limit int, analyzeQuery bool) (*firestore.Query, error) {
collection := s.FirestoreClient().Collection(collectionPath)
query := collection.Query
// Process and apply filters if template is provided
if filter != nil {
query = query.WhereEntity(filter)
}
if len(selectFields) > 0 {
query = query.Select(selectFields...)
}
if field != "" {
query = query.OrderBy(field, direction)
}
query = query.Limit(limit)
// Apply analyze options if enabled
if analyzeQuery {
query = query.WithRunOptions(firestore.ExplainOptions{
Analyze: true,
})
}
return &query, nil
}
// QueryResult represents a document result from the query
type QueryResult struct {
ID string `json:"id"`
Path string `json:"path"`
Data map[string]any `json:"data"`
CreateTime any `json:"createTime,omitempty"`
UpdateTime any `json:"updateTime,omitempty"`
ReadTime any `json:"readTime,omitempty"`
}
// QueryResponse represents the full response including optional metrics
type QueryResponse struct {
Documents []QueryResult `json:"documents"`
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
}
// ExecuteQuery runs the query and formats the results
func (s *Source) ExecuteQuery(ctx context.Context, query *firestore.Query, analyzeQuery bool) (any, error) {
docIterator := query.Documents(ctx)
docs, err := docIterator.GetAll()
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
}
// Convert results to structured format
results := make([]QueryResult, len(docs))
for i, doc := range docs {
results[i] = QueryResult{
ID: doc.Ref.ID,
Path: doc.Ref.Path,
Data: doc.Data(),
CreateTime: doc.CreateTime,
UpdateTime: doc.UpdateTime,
ReadTime: doc.ReadTime,
}
}
// Return with explain metrics if requested
if analyzeQuery {
explainMetrics, err := getExplainMetrics(docIterator)
if err == nil && explainMetrics != nil {
response := QueryResponse{
Documents: results,
ExplainMetrics: explainMetrics,
}
return response, nil
}
}
return results, nil
}
// getExplainMetrics extracts explain metrics from the query iterator
func getExplainMetrics(docIterator *firestore.DocumentIterator) (map[string]any, error) {
explainMetrics, err := docIterator.ExplainMetrics()
if err != nil || explainMetrics == nil {
return nil, err
}
metricsData := make(map[string]any)
// Add plan summary if available
if explainMetrics.PlanSummary != nil {
planSummary := make(map[string]any)
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
metricsData["planSummary"] = planSummary
}
// Add execution stats if available
if explainMetrics.ExecutionStats != nil {
executionStats := make(map[string]any)
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
}
if explainMetrics.ExecutionStats.DebugStats != nil {
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
}
metricsData["executionStats"] = executionStats
}
return metricsData, nil
}
func (s *Source) GetDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
// Create document references from paths
docRefs := make([]*firestore.DocumentRef, len(documentPaths))
for i, path := range documentPaths {
docRefs[i] = s.FirestoreClient().Doc(path)
}
// Get all documents
snapshots, err := s.FirestoreClient().GetAll(ctx, docRefs)
if err != nil {
return nil, fmt.Errorf("failed to get documents: %w", err)
}
// Convert snapshots to response data
results := make([]any, len(snapshots))
for i, snapshot := range snapshots {
docData := make(map[string]any)
docData["path"] = documentPaths[i]
docData["exists"] = snapshot.Exists()
if snapshot.Exists() {
docData["data"] = snapshot.Data()
docData["createTime"] = snapshot.CreateTime
docData["updateTime"] = snapshot.UpdateTime
docData["readTime"] = snapshot.ReadTime
}
results[i] = docData
}
return results, nil
}
func (s *Source) AddDocuments(ctx context.Context, collectionPath string, documentData any, returnData bool) (map[string]any, error) {
// Get the collection reference
collection := s.FirestoreClient().Collection(collectionPath)
// Add the document to the collection
docRef, writeResult, err := collection.Add(ctx, documentData)
if err != nil {
return nil, fmt.Errorf("failed to add document: %w", err)
}
// Build the response
response := map[string]any{
"documentPath": docRef.Path,
"createTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
}
// Add document data if requested
if returnData {
// Fetch the updated document to return the current state
snapshot, err := docRef.Get(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
}
// Convert the document data back to simple JSON format
simplifiedData := FirestoreValueToJSON(snapshot.Data())
response["documentData"] = simplifiedData
}
return response, nil
}
func (s *Source) UpdateDocument(ctx context.Context, documentPath string, updates []firestore.Update, documentData any, returnData bool) (map[string]any, error) {
// Get the document reference
docRef := s.FirestoreClient().Doc(documentPath)
// Prepare update data
var writeResult *firestore.WriteResult
var writeErr error
if len(updates) > 0 {
writeResult, writeErr = docRef.Update(ctx, updates)
} else {
writeResult, writeErr = docRef.Set(ctx, documentData, firestore.MergeAll)
}
if writeErr != nil {
return nil, fmt.Errorf("failed to update document: %w", writeErr)
}
// Build the response
response := map[string]any{
"documentPath": docRef.Path,
"updateTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
}
// Add document data if requested
if returnData {
// Fetch the updated document to return the current state
snapshot, err := docRef.Get(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
}
// Convert the document data to simple JSON format
simplifiedData := FirestoreValueToJSON(snapshot.Data())
response["documentData"] = simplifiedData
}
return response, nil
}
func (s *Source) DeleteDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
// Create a BulkWriter to handle multiple deletions efficiently
bulkWriter := s.FirestoreClient().BulkWriter(ctx)
// Keep track of jobs for each document
jobs := make([]*firestore.BulkWriterJob, len(documentPaths))
// Add all delete operations to the BulkWriter
for i, path := range documentPaths {
docRef := s.FirestoreClient().Doc(path)
job, err := bulkWriter.Delete(docRef)
if err != nil {
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
}
jobs[i] = job
}
// End the BulkWriter to execute all operations
bulkWriter.End()
// Collect results
results := make([]any, len(documentPaths))
for i, job := range jobs {
docData := make(map[string]any)
docData["path"] = documentPaths[i]
// Wait for the job to complete and get the result
_, err := job.Results()
if err != nil {
docData["success"] = false
docData["error"] = err.Error()
} else {
docData["success"] = true
}
results[i] = docData
}
return results, nil
}
func (s *Source) ListCollections(ctx context.Context, parentPath string) ([]any, error) {
var collectionRefs []*firestore.CollectionRef
var err error
if parentPath != "" {
// List subcollections of the specified document
docRef := s.FirestoreClient().Doc(parentPath)
collectionRefs, err = docRef.Collections(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
}
} else {
// List root collections
collectionRefs, err = s.FirestoreClient().Collections(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to list root collections: %w", err)
}
}
// Convert collection references to response data
results := make([]any, len(collectionRefs))
for i, collRef := range collectionRefs {
collData := make(map[string]any)
collData["id"] = collRef.ID
collData["path"] = collRef.Path
// If this is a subcollection, include parent information
if collRef.Parent != nil {
collData["parent"] = collRef.Parent.Path
}
results[i] = collData
}
return results, nil
}
func (s *Source) GetRules(ctx context.Context) (any, error) {
// Get the latest release for Firestore
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", s.GetProjectId(), s.GetDatabaseId())
release, err := s.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
}
if release.RulesetName == "" {
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", s.GetProjectId(), s.GetDatabaseId())
}
// Get the ruleset content
ruleset, err := s.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
}
if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
return nil, fmt.Errorf("no rules files found in ruleset")
}
return ruleset, nil
}
// SourcePosition represents the location of an issue in the source
type SourcePosition struct {
FileName string `json:"fileName,omitempty"`
Line int64 `json:"line"` // 1-based
Column int64 `json:"column"` // 1-based
CurrentOffset int64 `json:"currentOffset"` // 0-based, inclusive start
EndOffset int64 `json:"endOffset"` // 0-based, exclusive end
}
// Issue represents a validation issue in the rules
type Issue struct {
SourcePosition SourcePosition `json:"sourcePosition"`
Description string `json:"description"`
Severity string `json:"severity"`
}
// ValidationResult represents the result of rules validation
type ValidationResult struct {
Valid bool `json:"valid"`
IssueCount int `json:"issueCount"`
FormattedIssues string `json:"formattedIssues,omitempty"`
RawIssues []Issue `json:"rawIssues,omitempty"`
}
func (s *Source) ValidateRules(ctx context.Context, sourceParam string) (any, error) {
// Create test request
testRequest := &firebaserules.TestRulesetRequest{
Source: &firebaserules.Source{
Files: []*firebaserules.File{
{
Name: "firestore.rules",
Content: sourceParam,
},
},
},
// We don't need test cases for validation only
TestSuite: &firebaserules.TestSuite{
TestCases: []*firebaserules.TestCase{},
},
}
// Call the test API
projectName := fmt.Sprintf("projects/%s", s.GetProjectId())
response, err := s.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to validate rules: %w", err)
}
// Process the response
if len(response.Issues) == 0 {
return ValidationResult{
Valid: true,
IssueCount: 0,
FormattedIssues: "✓ No errors detected. Rules are valid.",
}, nil
}
// Convert issues to our format
issues := make([]Issue, len(response.Issues))
for i, issue := range response.Issues {
issues[i] = Issue{
Description: issue.Description,
Severity: issue.Severity,
SourcePosition: SourcePosition{
FileName: issue.SourcePosition.FileName,
Line: issue.SourcePosition.Line,
Column: issue.SourcePosition.Column,
CurrentOffset: issue.SourcePosition.CurrentOffset,
EndOffset: issue.SourcePosition.EndOffset,
},
}
}
// Format issues
sourceLines := strings.Split(sourceParam, "\n")
var formattedOutput []string
formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))
for _, issue := range issues {
issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
issue.Severity,
issue.Description,
issue.SourcePosition.Line,
issue.SourcePosition.Column)
if issue.SourcePosition.Line > 0 {
lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
if lineIndex >= 0 && lineIndex < len(sourceLines) {
errorLine := sourceLines[lineIndex]
issueString += fmt.Sprintf("\n```\n%s", errorLine)
// Add carets if we have column and offset information
if issue.SourcePosition.Column > 0 &&
issue.SourcePosition.CurrentOffset >= 0 &&
issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {
startColumn := int(issue.SourcePosition.Column - 1) // 0-based
errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)
if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
padding := strings.Repeat(" ", startColumn)
carets := strings.Repeat("^", errorTokenLength)
issueString += fmt.Sprintf("\n%s%s", padding, carets)
}
}
issueString += "\n```"
}
}
formattedOutput = append(formattedOutput, issueString)
}
formattedIssues := strings.Join(formattedOutput, "\n\n")
return ValidationResult{
Valid: false,
IssueCount: len(issues),
FormattedIssues: formattedIssues,
RawIssues: issues,
}, nil
}
func initFirestoreConnection(
ctx context.Context,
tracer trace.Tracer,

View File

@@ -16,7 +16,6 @@ package firestore_test
import (
"testing"
"time"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
@@ -129,37 +128,3 @@ func TestFailParseFromYamlFirestore(t *testing.T) {
})
}
}
func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
// Test round-trip conversion
original := map[string]any{
"name": "Test",
"count": int64(42),
"price": 19.99,
"active": true,
"tags": []any{"tag1", "tag2"},
"metadata": map[string]any{
"created": time.Now(),
},
"nullField": nil,
}
// Convert to JSON representation
jsonRepresentation := firestore.FirestoreValueToJSON(original)
// Verify types are simplified
jsonMap, ok := jsonRepresentation.(map[string]any)
if !ok {
t.Fatalf("Expected map, got %T", jsonRepresentation)
}
// Time should be converted to string
metadata, ok := jsonMap["metadata"].(map[string]any)
if !ok {
t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
}
_, ok = metadata["created"].(string)
if !ok {
t.Errorf("created should be a string, got %T", metadata["created"])
}
}

View File

@@ -16,9 +16,7 @@ package http
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
@@ -145,28 +143,3 @@ func (s *Source) HttpQueryParams() map[string]string {
func (s *Source) Client() *http.Client {
return s.client
}
func (s *Source) RunRequest(req *http.Request) (any, error) {
// Make request and fetch response
resp, err := s.Client().Do(req)
if err != nil {
return nil, fmt.Errorf("error making HTTP request: %s", err)
}
defer resp.Body.Close()
var body []byte
body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
}
var data any
if err = json.Unmarshal(body, &data); err != nil {
// if unable to unmarshal data, return result as string.
return string(body), nil
}
return data, nil
}

View File

@@ -16,21 +16,15 @@ package serverlessspark
import (
"context"
"encoding/json"
"fmt"
"time"
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
longrunning "cloud.google.com/go/longrunning/autogen"
"cloud.google.com/go/longrunning/autogen/longrunningpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/protobuf/encoding/protojson"
)
const SourceKind string = "serverless-spark"
@@ -127,168 +121,3 @@ func (s *Source) Close() error {
}
return nil
}
func (s *Source) CancelOperation(ctx context.Context, operation string) (any, error) {
req := &longrunningpb.CancelOperationRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", s.GetProject(), s.GetLocation(), operation),
}
client, err := s.GetOperationsClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get operations client: %w", err)
}
err = client.CancelOperation(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to cancel operation: %w", err)
}
return fmt.Sprintf("Cancelled [%s].", operation), nil
}
func (s *Source) CreateBatch(ctx context.Context, batch *dataprocpb.Batch) (map[string]any, error) {
req := &dataprocpb.CreateBatchRequest{
Parent: fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation()),
Batch: batch,
}
client := s.GetBatchControllerClient()
op, err := client.CreateBatch(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to create batch: %w", err)
}
meta, err := op.Metadata()
if err != nil {
return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
}
projectID, location, batchID, err := ExtractBatchDetails(meta.GetBatch())
if err != nil {
return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
}
consoleUrl := BatchConsoleURL(projectID, location, batchID)
logsUrl := BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})
wrappedResult := map[string]any{
"opMetadata": meta,
"consoleUrl": consoleUrl,
"logsUrl": logsUrl,
}
return wrappedResult, nil
}
// ListBatchesResponse is the response from the list batches API.
type ListBatchesResponse struct {
Batches []Batch `json:"batches"`
NextPageToken string `json:"nextPageToken"`
}
// Batch represents a single batch job.
type Batch struct {
Name string `json:"name"`
UUID string `json:"uuid"`
State string `json:"state"`
Creator string `json:"creator"`
CreateTime string `json:"createTime"`
Operation string `json:"operation"`
ConsoleURL string `json:"consoleUrl"`
LogsURL string `json:"logsUrl"`
}
func (s *Source) ListBatches(ctx context.Context, ps *int, pt, filter string) (any, error) {
client := s.GetBatchControllerClient()
parent := fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation())
req := &dataprocpb.ListBatchesRequest{
Parent: parent,
OrderBy: "create_time desc",
}
if ps != nil {
req.PageSize = int32(*ps)
}
if pt != "" {
req.PageToken = pt
}
if filter != "" {
req.Filter = filter
}
it := client.ListBatches(ctx, req)
pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)
var batchPbs []*dataprocpb.Batch
nextPageToken, err := pager.NextPage(&batchPbs)
if err != nil {
return nil, fmt.Errorf("failed to list batches: %w", err)
}
batches, err := ToBatches(batchPbs)
if err != nil {
return nil, err
}
return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
}
// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
batches := make([]Batch, 0, len(batchPbs))
for _, batchPb := range batchPbs {
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating console url: %v", err)
}
logsUrl, err := BatchLogsURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating logs url: %v", err)
}
batch := Batch{
Name: batchPb.Name,
UUID: batchPb.Uuid,
State: batchPb.State.Enum().String(),
Creator: batchPb.Creator,
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
Operation: batchPb.Operation,
ConsoleURL: consoleUrl,
LogsURL: logsUrl,
}
batches = append(batches, batch)
}
return batches, nil
}
func (s *Source) GetBatch(ctx context.Context, name string) (map[string]any, error) {
client := s.GetBatchControllerClient()
req := &dataprocpb.GetBatchRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", s.GetProject(), s.GetLocation(), name),
}
batchPb, err := client.GetBatch(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to get batch: %w", err)
}
jsonBytes, err := protojson.Marshal(batchPb)
if err != nil {
return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
}
var result map[string]any
if err := json.Unmarshal(jsonBytes, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
}
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating console url: %v", err)
}
logsUrl, err := BatchLogsURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating logs url: %v", err)
}
wrappedResult := map[string]any{
"consoleUrl": consoleUrl,
"logsUrl": logsUrl,
"batch": result,
}
return wrappedResult, nil
}

View File

@@ -16,13 +16,22 @@ package fhirfetchpage
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
"net/http"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const kind string = "cloud-healthcare-fhir-fetch-page"
@@ -45,8 +54,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
FHIRFetchPage(context.Context, string, string) (any, error)
}
type Config struct {
@@ -104,11 +118,48 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var httpClient *http.Client
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
httpClient = oauth2.NewClient(ctx, ts)
} else {
// The source.Service() object holds a client with the default credentials.
// However, the client is not exported, so we have to create a new one.
var err error
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
if err != nil {
return nil, fmt.Errorf("failed to create default http client: %w", err)
}
}
return source.FHIRFetchPage(ctx, url, tokenStr)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create http request: %w", err)
}
req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("read: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
var jsonMap map[string]interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
}
return jsonMap, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,16 +16,20 @@ package fhirpatienteverything
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/googleapi"
"google.golang.org/api/healthcare/v1"
)
const kind string = "cloud-healthcare-fhir-patient-everything"
@@ -50,9 +54,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
FHIRPatientEverything(string, string, string, []googleapi.CallOption) (any, error)
}
type Config struct {
@@ -131,11 +139,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID)
var opts []googleapi.CallOption
if val, ok := params.AsMap()[typeFilterKey]; ok {
types, ok := val.([]any)
@@ -159,7 +176,25 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
opts = append(opts, googleapi.QueryParameter("_since", sinceStr))
}
}
return source.FHIRPatientEverything(storeID, patientID, tokenStr, opts)
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.PatientEverything(name).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to call patient everything for %q: %w", name, err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("patient-everything: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
var jsonMap map[string]interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
}
return jsonMap, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,16 +16,20 @@ package fhirpatientsearch
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/googleapi"
"google.golang.org/api/healthcare/v1"
)
const kind string = "cloud-healthcare-fhir-patient-search"
@@ -66,9 +70,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
FHIRPatientSearch(string, string, []googleapi.CallOption) (any, error)
}
type Config struct {
@@ -161,9 +169,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
var summary bool
@@ -232,7 +248,26 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if summary {
opts = append(opts, googleapi.QueryParameter("_summary", "text"))
}
return source.FHIRPatientSearch(storeID, tokenStr, opts)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search patient resources: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
var jsonMap map[string]interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
}
return jsonMap, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -21,6 +21,7 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
@@ -43,8 +44,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
GetDataset(string) (*healthcare.Dataset, error)
}
type Config struct {
@@ -95,11 +100,27 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
return source.GetDataset(tokenStr)
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
return dataset, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -21,6 +21,7 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -44,9 +45,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
GetDICOMStore(string, string) (*healthcare.DicomStore, error)
}
type Config struct {
@@ -112,15 +117,31 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
return source.GetDICOMStore(storeID, tokenStr)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
}
return store, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -21,6 +21,7 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -44,9 +45,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
GetDICOMStoreMetrics(string, string) (*healthcare.DicomStoreMetrics, error)
}
type Config struct {
@@ -112,15 +117,31 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
return source.GetDICOMStoreMetrics(storeID, tokenStr)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
}
return store, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,14 +16,18 @@ package getfhirresource
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
)
const kind string = "cloud-healthcare-get-fhir-resource"
@@ -47,9 +51,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
GetFHIRResource(string, string, string, string) (any, error)
}
type Config struct {
@@ -126,15 +134,46 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", typeKey)
}
resID, ok := params.AsMap()[idKey].(string)
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
return source.GetFHIRResource(storeID, resType, resID, tokenStr)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", source.Project(), source.Region(), source.DatasetID(), storeID, resType, resID)
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
resp, err := call.Do()
if err != nil {
return nil, fmt.Errorf("failed to get fhir resource %q: %w", name, err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("read: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
var jsonMap map[string]interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
}
return jsonMap, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -21,6 +21,7 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -44,9 +45,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
GetFHIRStore(string, string) (*healthcare.FhirStore, error)
}
type Config struct {
@@ -112,15 +117,31 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
return source.GetFHIRStore(storeID, tokenStr)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
}
return store, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -21,6 +21,7 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -44,9 +45,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
GetFHIRStoreMetrics(string, string) (*healthcare.FhirStoreMetrics, error)
}
type Config struct {
@@ -112,15 +117,31 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
return source.GetFHIRStoreMetrics(storeID, tokenStr)
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
}
return store, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -17,10 +17,12 @@ package listdicomstores
import (
"context"
"fmt"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
@@ -43,8 +45,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error)
}
type Config struct {
@@ -95,11 +102,41 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
return source.ListDICOMStores(tokenStr)
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
var filtered []*healthcare.DicomStore
for _, store := range stores.DicomStores {
if len(source.AllowedDICOMStores()) == 0 {
filtered = append(filtered, store)
continue
}
if len(store.Name) == 0 {
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := source.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
return filtered, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -17,10 +17,12 @@ package listfhirstores
import (
"context"
"fmt"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
@@ -43,8 +45,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
ListFHIRStores(string) ([]*healthcare.FhirStore, error)
}
type Config struct {
@@ -95,11 +102,41 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
return source.ListFHIRStores(tokenStr)
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
var filtered []*healthcare.FhirStore
for _, store := range stores.FhirStores {
if len(source.AllowedFHIRStores()) == 0 {
filtered = append(filtered, store)
continue
}
if len(store.Name) == 0 {
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := source.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
return filtered, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,14 +16,18 @@ package retrieverendereddicominstance
import (
"context"
"encoding/base64"
"fmt"
"io"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
)
const kind string = "cloud-healthcare-retrieve-rendered-dicom-instance"
@@ -49,9 +53,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
RetrieveRenderedDICOMInstance(string, string, string, string, int, string) (any, error)
}
type Config struct {
@@ -127,10 +135,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
study, ok := params.AsMap()[studyInstanceUIDKey].(string)
if !ok {
return nil, fmt.Errorf("invalid '%s' parameter; expected a string", studyInstanceUIDKey)
@@ -147,7 +165,25 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok {
return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey)
}
return source.RetrieveRenderedDICOMInstance(storeID, study, series, sop, frame, tokenStr)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
call.Header().Set("Accept", "image/jpeg")
resp, err := call.Do()
if err != nil {
return nil, fmt.Errorf("unable to retrieve dicom instance rendered image: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("RetrieveRendered: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
base64String := base64.StdEncoding.EncodeToString(respBytes)
return base64String, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,16 +16,20 @@ package searchdicominstances
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/googleapi"
"google.golang.org/api/healthcare/v1"
)
const kind string = "cloud-healthcare-search-dicom-instances"
@@ -56,9 +60,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
}
type Config struct {
@@ -136,13 +144,23 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
opts, err := common.ParseDICOMSearchParameters(params, []string{sopInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})
@@ -173,7 +191,29 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
}
return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search dicom instances: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
if len(respBytes) == 0 {
return []interface{}{}, nil
}
var result []interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
}
return result, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,15 +16,18 @@ package searchdicomseries
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/googleapi"
"google.golang.org/api/healthcare/v1"
)
const kind string = "cloud-healthcare-search-dicom-series"
@@ -54,9 +57,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
}
type Config struct {
@@ -138,9 +145,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
opts, err := common.ParseDICOMSearchParameters(params, []string{seriesInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})
@@ -158,7 +174,29 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
dicomWebPath = fmt.Sprintf("studies/%s/series", id)
}
}
return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search dicom series: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
if len(respBytes) == 0 {
return []interface{}{}, nil
}
var result []interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
}
return result, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,15 +16,18 @@ package searchdicomstudies
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/googleapi"
"google.golang.org/api/healthcare/v1"
)
const kind string = "cloud-healthcare-search-dicom-studies"
@@ -52,9 +55,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
}
type Config struct {
@@ -129,20 +136,51 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
opts, err := common.ParseDICOMSearchParameters(params, []string{studyInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey})
if err != nil {
return nil, err
}
dicomWebPath := "studies"
return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, "studies").Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search dicom studies: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
if len(respBytes) == 0 {
return []interface{}{}, nil
}
var result []interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
}
return result, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -18,6 +18,7 @@ import (
"context"
"fmt"
dataplexapi "cloud.google.com/go/dataplex/apiv1"
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
@@ -43,7 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
LookupEntry(context.Context, string, int, []string, string) (*dataplexpb.Entry, error)
CatalogClient() *dataplexapi.CatalogClient
}
type Config struct {
@@ -117,6 +118,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
paramsMap := params.AsMap()
viewMap := map[int]dataplexpb.EntryView{
1: dataplexpb.EntryView_BASIC,
2: dataplexpb.EntryView_FULL,
3: dataplexpb.EntryView_CUSTOM,
4: dataplexpb.EntryView_ALL,
}
name, _ := paramsMap["name"].(string)
entry, _ := paramsMap["entry"].(string)
view, _ := paramsMap["view"].(int)
@@ -125,7 +132,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("can't convert aspectTypes to array of strings: %s", err)
}
aspectTypes := aspectTypeSlice.([]string)
return source.LookupEntry(ctx, name, view, aspectTypes, entry)
req := &dataplexpb.LookupEntryRequest{
Name: name,
View: viewMap[view],
AspectTypes: aspectTypes,
Entry: entry,
}
result, err := source.CatalogClient().LookupEntry(ctx, req)
if err != nil {
return nil, err
}
return result, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -18,7 +18,9 @@ import (
"context"
"fmt"
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
dataplexapi "cloud.google.com/go/dataplex/apiv1"
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
"github.com/cenkalti/backoff/v5"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -43,7 +45,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
SearchAspectTypes(context.Context, string, int, string) ([]*dataplexpb.AspectType, error)
CatalogClient() *dataplexapi.CatalogClient
ProjectID() string
}
type Config struct {
@@ -98,11 +101,61 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
// Invoke the tool with the provided parameters
paramsMap := params.AsMap()
query, _ := paramsMap["query"].(string)
pageSize, _ := paramsMap["pageSize"].(int)
pageSize := int32(paramsMap["pageSize"].(int))
orderBy, _ := paramsMap["orderBy"].(string)
return source.SearchAspectTypes(ctx, query, pageSize, orderBy)
// Create SearchEntriesRequest with the provided parameters
req := &dataplexpb.SearchEntriesRequest{
Query: query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype",
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
PageSize: pageSize,
OrderBy: orderBy,
SemanticSearch: true,
}
// Perform the search using the CatalogClient - this will return an iterator
it := source.CatalogClient().SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
}
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
getAspectBackOff := backoff.NewExponentialBackOff()
// Iterate through the search results and call GetAspectType for each result using the resource name
var results []*dataplexpb.AspectType
for {
entry, err := it.Next()
if err != nil {
break
}
resourceName := entry.DataplexEntry.GetEntrySource().Resource
getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
Name: resourceName,
}
operation := func() (*dataplexpb.AspectType, error) {
aspectType, err := source.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
if err != nil {
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
}
return aspectType, nil
}
// Retry the GetAspectType operation with exponential backoff
aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
if err != nil {
return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
}
results = append(results, aspectType)
}
return results, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -18,7 +18,8 @@ import (
"context"
"fmt"
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
dataplexapi "cloud.google.com/go/dataplex/apiv1"
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -43,7 +44,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
SearchEntries(context.Context, string, int, string) ([]*dataplexpb.SearchEntriesResult, error)
CatalogClient() *dataplexapi.CatalogClient
ProjectID() string
}
type Config struct {
@@ -98,11 +100,34 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
query, _ := paramsMap["query"].(string)
pageSize, _ := paramsMap["pageSize"].(int)
pageSize := int32(paramsMap["pageSize"].(int))
orderBy, _ := paramsMap["orderBy"].(string)
return source.SearchEntries(ctx, query, pageSize, orderBy)
req := &dataplexpb.SearchEntriesRequest{
Query: query,
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
PageSize: pageSize,
OrderBy: orderBy,
SemanticSearch: true,
}
it := source.CatalogClient().SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
}
var results []*dataplexpb.SearchEntriesResult
for {
entry, err := it.Next()
if err != nil {
break
}
results = append(results, entry)
}
return results, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -48,7 +48,6 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
AddDocuments(context.Context, string, any, bool) (map[string]any, error)
}
type Config struct {
@@ -135,20 +134,24 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
mapParams := params.AsMap()
// Get collection path
collectionPath, ok := mapParams[collectionPathKey].(string)
if !ok || collectionPath == "" {
return nil, fmt.Errorf("invalid or missing '%s' parameter", collectionPathKey)
}
// Validate collection path
if err := util.ValidateCollectionPath(collectionPath); err != nil {
return nil, fmt.Errorf("invalid collection path: %w", err)
}
// Get document data
documentDataRaw, ok := mapParams[documentDataKey]
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter", documentDataKey)
}
// Convert the document data from JSON format to Firestore format
// The client is passed to handle referenceValue types
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
@@ -161,7 +164,30 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
returnData = val
}
return source.AddDocuments(ctx, collectionPath, documentData, returnData)
// Get the collection reference
collection := source.FirestoreClient().Collection(collectionPath)
// Add the document to the collection
docRef, writeResult, err := collection.Add(ctx, documentData)
if err != nil {
return nil, fmt.Errorf("failed to add document: %w", err)
}
// Build the response
response := map[string]any{
"documentPath": docRef.Path,
"createTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
}
// Add document data if requested
if returnData {
// Convert the document data back to simple JSON format
simplifiedData := util.FirestoreValueToJSON(documentData)
response["documentData"] = simplifiedData
}
return response, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -46,7 +46,6 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
DeleteDocuments(context.Context, []string) ([]any, error)
}
type Config struct {
@@ -105,6 +104,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected an array", documentPathsKey)
}
if len(documentPathsRaw) == 0 {
return nil, fmt.Errorf("'%s' parameter cannot be empty", documentPathsKey)
}
@@ -126,7 +126,45 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
}
}
return source.DeleteDocuments(ctx, documentPaths)
// Create a BulkWriter to handle multiple deletions efficiently
bulkWriter := source.FirestoreClient().BulkWriter(ctx)
// Keep track of jobs for each document
jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths))
// Add all delete operations to the BulkWriter
for i, path := range documentPaths {
docRef := source.FirestoreClient().Doc(path)
job, err := bulkWriter.Delete(docRef)
if err != nil {
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
}
jobs[i] = job
}
// End the BulkWriter to execute all operations
bulkWriter.End()
// Collect results
results := make([]any, len(documentPaths))
for i, job := range jobs {
docData := make(map[string]any)
docData["path"] = documentPaths[i]
// Wait for the job to complete and get the result
_, err := job.Results()
if err != nil {
docData["success"] = false
docData["error"] = err.Error()
} else {
docData["success"] = true
}
results[i] = docData
}
return results, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -46,7 +46,6 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
GetDocuments(context.Context, []string) ([]any, error)
}
type Config struct {
@@ -127,7 +126,37 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
}
}
return source.GetDocuments(ctx, documentPaths)
// Create document references from paths
docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths))
for i, path := range documentPaths {
docRefs[i] = source.FirestoreClient().Doc(path)
}
// Get all documents
snapshots, err := source.FirestoreClient().GetAll(ctx, docRefs)
if err != nil {
return nil, fmt.Errorf("failed to get documents: %w", err)
}
// Convert snapshots to response data
results := make([]any, len(snapshots))
for i, snapshot := range snapshots {
docData := make(map[string]any)
docData["path"] = documentPaths[i]
docData["exists"] = snapshot.Exists()
if snapshot.Exists() {
docData["data"] = snapshot.Data()
docData["createTime"] = snapshot.CreateTime
docData["updateTime"] = snapshot.UpdateTime
docData["readTime"] = snapshot.ReadTime
}
results[i] = docData
}
return results, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -44,7 +44,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirebaseRulesClient() *firebaserules.Service
GetRules(context.Context) (any, error)
GetProjectId() string
GetDatabaseId() string
}
type Config struct {
@@ -97,7 +98,29 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
return source.GetRules(ctx)
// Get the latest release for Firestore
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", source.GetProjectId(), source.GetDatabaseId())
release, err := source.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
}
if release.RulesetName == "" {
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", source.GetProjectId(), source.GetDatabaseId())
}
// Get the ruleset content
ruleset, err := source.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
}
if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
return nil, fmt.Errorf("no rules files found in ruleset")
}
return ruleset, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -46,7 +46,6 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
ListCollections(context.Context, string) ([]any, error)
}
type Config struct {
@@ -103,15 +102,47 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
mapParams := params.AsMap()
var collectionRefs []*firestoreapi.CollectionRef
// Check if parentPath is provided
parentPath, _ := mapParams[parentPathKey].(string)
if parentPath != "" {
parentPath, hasParent := mapParams[parentPathKey].(string)
if hasParent && parentPath != "" {
// Validate parent document path
if err := util.ValidateDocumentPath(parentPath); err != nil {
return nil, fmt.Errorf("invalid parent document path: %w", err)
}
// List subcollections of the specified document
docRef := source.FirestoreClient().Doc(parentPath)
collectionRefs, err = docRef.Collections(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
}
} else {
// List root collections
collectionRefs, err = source.FirestoreClient().Collections(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to list root collections: %w", err)
}
}
return source.ListCollections(ctx, parentPath)
// Convert collection references to response data
results := make([]any, len(collectionRefs))
for i, collRef := range collectionRefs {
collData := make(map[string]any)
collData["id"] = collRef.ID
collData["path"] = collRef.Path
// If this is a subcollection, include parent information
if collRef.Parent != nil {
collData["parent"] = collRef.Parent.Path
}
results[i] = collData
}
return results, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -36,6 +36,27 @@ const (
defaultLimit = 100
)
// Firestore operators
var validOperators = map[string]bool{
"<": true,
"<=": true,
">": true,
">=": true,
"==": true,
"!=": true,
"array-contains": true,
"array-contains-any": true,
"in": true,
"not-in": true,
}
// Error messages
const (
errFilterParseFailed = "failed to parse filters: %w"
errQueryExecutionFailed = "failed to execute query: %w"
errLimitParseFailed = "failed to parse limit value '%s': %w"
)
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
@@ -53,8 +74,6 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
// compatibleSource defines the interface for sources that can provide a Firestore client
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
BuildQuery(string, firestoreapi.EntityFilter, []string, string, firestoreapi.Direction, int, bool) (*firestoreapi.Query, error)
ExecuteQuery(context.Context, *firestoreapi.Query, bool) (any, error)
}
// Config represents the configuration for the Firestore query tool
@@ -120,6 +139,15 @@ func (t Tool) ToConfig() tools.ToolConfig {
return t.Config
}
// SimplifiedFilter represents the simplified filter format
type SimplifiedFilter struct {
And []SimplifiedFilter `json:"and,omitempty"`
Or []SimplifiedFilter `json:"or,omitempty"`
Field string `json:"field,omitempty"`
Op string `json:"op,omitempty"`
Value interface{} `json:"value,omitempty"`
}
// OrderByConfig represents ordering configuration
type OrderByConfig struct {
Field string `json:"field"`
@@ -134,27 +162,20 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
return firestoreapi.Asc
}
// SimplifiedFilter represents the simplified filter format
type SimplifiedFilter struct {
And []SimplifiedFilter `json:"and,omitempty"`
Or []SimplifiedFilter `json:"or,omitempty"`
Field string `json:"field,omitempty"`
Op string `json:"op,omitempty"`
Value interface{} `json:"value,omitempty"`
// QueryResult represents a document result from the query
type QueryResult struct {
ID string `json:"id"`
Path string `json:"path"`
Data map[string]any `json:"data"`
CreateTime interface{} `json:"createTime,omitempty"`
UpdateTime interface{} `json:"updateTime,omitempty"`
ReadTime interface{} `json:"readTime,omitempty"`
}
// Firestore operators
var validOperators = map[string]bool{
"<": true,
"<=": true,
">": true,
">=": true,
"==": true,
"!=": true,
"array-contains": true,
"array-contains-any": true,
"in": true,
"not-in": true,
// QueryResponse represents the full response including optional metrics
type QueryResponse struct {
Documents []QueryResult `json:"documents"`
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
}
// Invoke executes the Firestore query based on the provided parameters
@@ -163,18 +184,34 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
// Process collection path with template substitution
collectionPath, err := parameters.PopulateTemplate("collectionPath", t.CollectionPath, paramsMap)
if err != nil {
return nil, fmt.Errorf("failed to process collection path: %w", err)
}
var filter firestoreapi.EntityFilter
// Build the query
query, err := t.buildQuery(source, collectionPath, paramsMap)
if err != nil {
return nil, err
}
// Execute the query and return results
return t.executeQuery(ctx, query)
}
// buildQuery constructs the Firestore query from parameters
func (t Tool) buildQuery(source compatibleSource, collectionPath string, params map[string]any) (*firestoreapi.Query, error) {
collection := source.FirestoreClient().Collection(collectionPath)
query := collection.Query
// Process and apply filters if template is provided
if t.Filters != "" {
// Apply template substitution to filters
filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, paramsMap)
filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, params)
if err != nil {
return nil, fmt.Errorf("failed to process filters template: %w", err)
}
@@ -182,43 +219,48 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// Parse the simplified filter format
var simplifiedFilter SimplifiedFilter
if err := json.Unmarshal([]byte(filtersJSON), &simplifiedFilter); err != nil {
return nil, fmt.Errorf("failed to parse filters: %w", err)
return nil, fmt.Errorf(errFilterParseFailed, err)
}
// Convert simplified filter to Firestore filter
filter = t.convertToFirestoreFilter(source, simplifiedFilter)
}
// Process and apply ordering
orderBy, err := t.getOrderBy(paramsMap)
if err != nil {
return nil, err
if filter := t.convertToFirestoreFilter(source, simplifiedFilter); filter != nil {
query = query.WhereEntity(filter)
}
}
// Process select fields
selectFields, err := t.processSelectFields(paramsMap)
selectFields, err := t.processSelectFields(params)
if err != nil {
return nil, err
}
// Process and apply limit
limit, err := t.getLimit(paramsMap)
if err != nil {
return nil, err
if len(selectFields) > 0 {
query = query.Select(selectFields...)
}
// prevent panic when accessing orderBy incase it is nil
var orderByField string
var orderByDirection firestoreapi.Direction
// Process and apply ordering
orderBy, err := t.getOrderBy(params)
if err != nil {
return nil, err
}
if orderBy != nil {
orderByField = orderBy.Field
orderByDirection = orderBy.GetDirection()
query = query.OrderBy(orderBy.Field, orderBy.GetDirection())
}
// Build the query
query, err := source.BuildQuery(collectionPath, filter, selectFields, orderByField, orderByDirection, limit, t.AnalyzeQuery)
// Process and apply limit
limit, err := t.getLimit(params)
if err != nil {
return nil, err
}
// Execute the query and return results
return source.ExecuteQuery(ctx, query, t.AnalyzeQuery)
query = query.Limit(limit)
// Apply analyze options if enabled
if t.AnalyzeQuery {
query = query.WithRunOptions(firestoreapi.ExplainOptions{
Analyze: true,
})
}
return &query, nil
}
// convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter
@@ -367,7 +409,7 @@ func (t Tool) getLimit(params map[string]any) (int, error) {
if processedValue != "" {
parsedLimit, err := strconv.Atoi(processedValue)
if err != nil {
return 0, fmt.Errorf("failed to parse limit value '%s': %w", processedValue, err)
return 0, fmt.Errorf(errLimitParseFailed, processedValue, err)
}
limit = parsedLimit
}
@@ -375,6 +417,78 @@ func (t Tool) getLimit(params map[string]any) (int, error) {
return limit, nil
}
// executeQuery runs the query and formats the results
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query) (any, error) {
docIterator := query.Documents(ctx)
docs, err := docIterator.GetAll()
if err != nil {
return nil, fmt.Errorf(errQueryExecutionFailed, err)
}
// Convert results to structured format
results := make([]QueryResult, len(docs))
for i, doc := range docs {
results[i] = QueryResult{
ID: doc.Ref.ID,
Path: doc.Ref.Path,
Data: doc.Data(),
CreateTime: doc.CreateTime,
UpdateTime: doc.UpdateTime,
ReadTime: doc.ReadTime,
}
}
// Return with explain metrics if requested
if t.AnalyzeQuery {
explainMetrics, err := t.getExplainMetrics(docIterator)
if err == nil && explainMetrics != nil {
response := QueryResponse{
Documents: results,
ExplainMetrics: explainMetrics,
}
return response, nil
}
}
return results, nil
}
// getExplainMetrics extracts explain metrics from the query iterator
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
explainMetrics, err := docIterator.ExplainMetrics()
if err != nil || explainMetrics == nil {
return nil, err
}
metricsData := make(map[string]any)
// Add plan summary if available
if explainMetrics.PlanSummary != nil {
planSummary := make(map[string]any)
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
metricsData["planSummary"] = planSummary
}
// Add execution stats if available
if explainMetrics.ExecutionStats != nil {
executionStats := make(map[string]any)
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
}
if explainMetrics.ExecutionStats.DebugStats != nil {
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
}
metricsData["executionStats"] = executionStats
}
return metricsData, nil
}
// ParseParams parses and validates input parameters
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.Parameters, data, claims)

View File

@@ -69,6 +69,7 @@ const (
errInvalidOperator = "unsupported operator: %s. Valid operators are: %v"
errMissingFilterValue = "no value specified for filter on field '%s'"
errOrderByParseFailed = "failed to parse orderBy: %w"
errQueryExecutionFailed = "failed to execute query: %w"
errTooManyFilters = "too many filters provided: %d (maximum: %d)"
)
@@ -89,8 +90,6 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
// compatibleSource defines the interface for sources that can provide a Firestore client
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
BuildQuery(string, firestoreapi.EntityFilter, []string, string, firestoreapi.Direction, int, bool) (*firestoreapi.Query, error)
ExecuteQuery(context.Context, *firestoreapi.Query, bool) (any, error)
}
// Config represents the configuration for the Firestore query collection tool
@@ -229,6 +228,22 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
return firestoreapi.Asc
}
// QueryResult represents a document result from the query
type QueryResult struct {
ID string `json:"id"`
Path string `json:"path"`
Data map[string]any `json:"data"`
CreateTime interface{} `json:"createTime,omitempty"`
UpdateTime interface{} `json:"updateTime,omitempty"`
ReadTime interface{} `json:"readTime,omitempty"`
}
// QueryResponse represents the full response including optional metrics
type QueryResponse struct {
Documents []QueryResult `json:"documents"`
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
}
// Invoke executes the Firestore query based on the provided parameters
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
@@ -242,37 +257,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, err
}
var filter firestoreapi.EntityFilter
// Apply filters
if len(queryParams.Filters) > 0 {
filterConditions := make([]firestoreapi.EntityFilter, 0, len(queryParams.Filters))
for _, filter := range queryParams.Filters {
filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
Path: filter.Field,
Operator: filter.Op,
Value: filter.Value,
})
}
filter = firestoreapi.AndFilter{
Filters: filterConditions,
}
}
// prevent panic incase queryParams.OrderBy is nil
var orderByField string
var orderByDirection firestoreapi.Direction
if queryParams.OrderBy != nil {
orderByField = queryParams.OrderBy.Field
orderByDirection = queryParams.OrderBy.GetDirection()
}
// Build the query
query, err := source.BuildQuery(queryParams.CollectionPath, filter, nil, orderByField, orderByDirection, queryParams.Limit, queryParams.AnalyzeQuery)
query, err := t.buildQuery(source, queryParams)
if err != nil {
return nil, err
}
return source.ExecuteQuery(ctx, query, queryParams.AnalyzeQuery)
// Execute the query and return results
return t.executeQuery(ctx, query, queryParams.AnalyzeQuery)
}
// queryParameters holds all parsed query parameters
@@ -388,6 +380,122 @@ func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) {
return &orderBy, nil
}
// buildQuery constructs the Firestore query from parameters
func (t Tool) buildQuery(source compatibleSource, params *queryParameters) (*firestoreapi.Query, error) {
collection := source.FirestoreClient().Collection(params.CollectionPath)
query := collection.Query
// Apply filters
if len(params.Filters) > 0 {
filterConditions := make([]firestoreapi.EntityFilter, 0, len(params.Filters))
for _, filter := range params.Filters {
filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
Path: filter.Field,
Operator: filter.Op,
Value: filter.Value,
})
}
query = query.WhereEntity(firestoreapi.AndFilter{
Filters: filterConditions,
})
}
// Apply ordering
if params.OrderBy != nil {
query = query.OrderBy(params.OrderBy.Field, params.OrderBy.GetDirection())
}
// Apply limit
query = query.Limit(params.Limit)
// Apply analyze options
if params.AnalyzeQuery {
query = query.WithRunOptions(firestoreapi.ExplainOptions{
Analyze: true,
})
}
return &query, nil
}
// executeQuery runs the query and formats the results
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query, analyzeQuery bool) (any, error) {
docIterator := query.Documents(ctx)
docs, err := docIterator.GetAll()
if err != nil {
return nil, fmt.Errorf(errQueryExecutionFailed, err)
}
// Convert results to structured format
results := make([]QueryResult, len(docs))
for i, doc := range docs {
results[i] = QueryResult{
ID: doc.Ref.ID,
Path: doc.Ref.Path,
Data: doc.Data(),
CreateTime: doc.CreateTime,
UpdateTime: doc.UpdateTime,
ReadTime: doc.ReadTime,
}
}
// Return with explain metrics if requested
if analyzeQuery {
explainMetrics, err := t.getExplainMetrics(docIterator)
if err == nil && explainMetrics != nil {
response := QueryResponse{
Documents: results,
ExplainMetrics: explainMetrics,
}
return response, nil
}
}
// Return just the documents
resultsAny := make([]any, len(results))
for i, r := range results {
resultsAny[i] = r
}
return resultsAny, nil
}
// getExplainMetrics extracts explain metrics from the query iterator
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
explainMetrics, err := docIterator.ExplainMetrics()
if err != nil || explainMetrics == nil {
return nil, err
}
metricsData := make(map[string]any)
// Add plan summary if available
if explainMetrics.PlanSummary != nil {
planSummary := make(map[string]any)
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
metricsData["planSummary"] = planSummary
}
// Add execution stats if available
if explainMetrics.ExecutionStats != nil {
executionStats := make(map[string]any)
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
}
if explainMetrics.ExecutionStats.DebugStats != nil {
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
}
metricsData["executionStats"] = executionStats
}
return metricsData, nil
}
// ParseParams parses and validates input parameters
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.Parameters, data, claims)

View File

@@ -50,7 +50,6 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
UpdateDocument(context.Context, string, []firestoreapi.Update, any, bool) (map[string]any, error)
}
type Config struct {
@@ -178,10 +177,23 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
}
// Use selective field update with update mask
updates := make([]firestoreapi.Update, 0, len(updatePaths))
var documentData any
// Get return document data flag
returnData := false
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
returnData = val
}
// Get the document reference
docRef := source.FirestoreClient().Doc(documentPath)
// Prepare update data
var writeResult *firestoreapi.WriteResult
var writeErr error
if len(updatePaths) > 0 {
// Use selective field update with update mask
updates := make([]firestoreapi.Update, 0, len(updatePaths))
// Convert document data without delete markers
dataMap, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
@@ -208,20 +220,41 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
Value: value,
})
}
writeResult, writeErr = docRef.Update(ctx, updates)
} else {
// Update all fields in the document data (merge)
documentData, err = util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
if err != nil {
return nil, fmt.Errorf("failed to convert document data: %w", err)
}
writeResult, writeErr = docRef.Set(ctx, documentData, firestoreapi.MergeAll)
}
// Get return document data flag
returnData := false
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
returnData = val
if writeErr != nil {
return nil, fmt.Errorf("failed to update document: %w", writeErr)
}
return source.UpdateDocument(ctx, documentPath, updates, documentData, returnData)
// Build the response
response := map[string]any{
"documentPath": docRef.Path,
"updateTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
}
// Add document data if requested
if returnData {
// Fetch the updated document to return the current state
snapshot, err := docRef.Get(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
}
// Convert the document data to simple JSON format
simplifiedData := util.FirestoreValueToJSON(snapshot.Data())
response["documentData"] = simplifiedData
}
return response, nil
}
// getFieldValue retrieves a value from a nested map using a dot-separated path

View File

@@ -17,6 +17,7 @@ package firestorevalidaterules
import (
"context"
"fmt"
"strings"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
@@ -49,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirebaseRulesClient() *firebaserules.Service
ValidateRules(context.Context, string) (any, error)
GetProjectId() string
}
type Config struct {
@@ -106,6 +107,30 @@ func (t Tool) ToConfig() tools.ToolConfig {
return t.Config
}
// Issue represents a validation issue in the rules
type Issue struct {
SourcePosition SourcePosition `json:"sourcePosition"`
Description string `json:"description"`
Severity string `json:"severity"`
}
// SourcePosition represents the location of an issue in the source
type SourcePosition struct {
FileName string `json:"fileName,omitempty"`
Line int64 `json:"line"` // 1-based
Column int64 `json:"column"` // 1-based
CurrentOffset int64 `json:"currentOffset"` // 0-based, inclusive start
EndOffset int64 `json:"endOffset"` // 0-based, exclusive end
}
// ValidationResult represents the result of rules validation
type ValidationResult struct {
Valid bool `json:"valid"`
IssueCount int `json:"issueCount"`
FormattedIssues string `json:"formattedIssues,omitempty"`
RawIssues []Issue `json:"rawIssues,omitempty"`
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
@@ -119,7 +144,114 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok || sourceParam == "" {
return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey)
}
return source.ValidateRules(ctx, sourceParam)
// Create test request
testRequest := &firebaserules.TestRulesetRequest{
Source: &firebaserules.Source{
Files: []*firebaserules.File{
{
Name: "firestore.rules",
Content: sourceParam,
},
},
},
// We don't need test cases for validation only
TestSuite: &firebaserules.TestSuite{
TestCases: []*firebaserules.TestCase{},
},
}
// Call the test API
projectName := fmt.Sprintf("projects/%s", source.GetProjectId())
response, err := source.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to validate rules: %w", err)
}
// Process the response
result := t.processValidationResponse(response, sourceParam)
return result, nil
}
func (t Tool) processValidationResponse(response *firebaserules.TestRulesetResponse, source string) ValidationResult {
if len(response.Issues) == 0 {
return ValidationResult{
Valid: true,
IssueCount: 0,
FormattedIssues: "✓ No errors detected. Rules are valid.",
}
}
// Convert issues to our format
issues := make([]Issue, len(response.Issues))
for i, issue := range response.Issues {
issues[i] = Issue{
Description: issue.Description,
Severity: issue.Severity,
SourcePosition: SourcePosition{
FileName: issue.SourcePosition.FileName,
Line: issue.SourcePosition.Line,
Column: issue.SourcePosition.Column,
CurrentOffset: issue.SourcePosition.CurrentOffset,
EndOffset: issue.SourcePosition.EndOffset,
},
}
}
// Format issues
formattedIssues := t.formatRulesetIssues(issues, source)
return ValidationResult{
Valid: false,
IssueCount: len(issues),
FormattedIssues: formattedIssues,
RawIssues: issues,
}
}
// formatRulesetIssues formats validation issues into a human-readable string with code snippets
func (t Tool) formatRulesetIssues(issues []Issue, rulesSource string) string {
sourceLines := strings.Split(rulesSource, "\n")
var formattedOutput []string
formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))
for _, issue := range issues {
issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
issue.Severity,
issue.Description,
issue.SourcePosition.Line,
issue.SourcePosition.Column)
if issue.SourcePosition.Line > 0 {
lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
if lineIndex >= 0 && lineIndex < len(sourceLines) {
errorLine := sourceLines[lineIndex]
issueString += fmt.Sprintf("\n```\n%s", errorLine)
// Add carets if we have column and offset information
if issue.SourcePosition.Column > 0 &&
issue.SourcePosition.CurrentOffset >= 0 &&
issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {
startColumn := int(issue.SourcePosition.Column - 1) // 0-based
errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)
if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
padding := strings.Repeat(" ", startColumn)
carets := strings.Repeat("^", errorTokenLength)
issueString += fmt.Sprintf("\n%s%s", padding, carets)
}
}
issueString += "\n```"
}
}
formattedOutput = append(formattedOutput, issueString)
}
return strings.Join(formattedOutput, "\n\n")
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -28,13 +28,13 @@ import (
// JSONToFirestoreValue converts a JSON value with type information to a Firestore-compatible value
// The input should be a map with a single key indicating the type (e.g., "stringValue", "integerValue")
// If a client is provided, referenceValue types will be converted to *firestore.DocumentRef
func JSONToFirestoreValue(value any, client *firestore.Client) (any, error) {
func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interface{}, error) {
if value == nil {
return nil, nil
}
switch v := value.(type) {
case map[string]any:
case map[string]interface{}:
// Check for typed values
if len(v) == 1 {
for key, val := range v {
@@ -92,7 +92,7 @@ func JSONToFirestoreValue(value any, client *firestore.Client) (any, error) {
return nil, fmt.Errorf("timestamp value must be a string")
case "geoPointValue":
// Convert to LatLng
if geoMap, ok := val.(map[string]any); ok {
if geoMap, ok := val.(map[string]interface{}); ok {
lat, latOk := geoMap["latitude"].(float64)
lng, lngOk := geoMap["longitude"].(float64)
if latOk && lngOk {
@@ -105,9 +105,9 @@ func JSONToFirestoreValue(value any, client *firestore.Client) (any, error) {
return nil, fmt.Errorf("invalid geopoint value format")
case "arrayValue":
// Convert array
if arrayMap, ok := val.(map[string]any); ok {
if values, ok := arrayMap["values"].([]any); ok {
result := make([]any, len(values))
if arrayMap, ok := val.(map[string]interface{}); ok {
if values, ok := arrayMap["values"].([]interface{}); ok {
result := make([]interface{}, len(values))
for i, item := range values {
converted, err := JSONToFirestoreValue(item, client)
if err != nil {
@@ -121,9 +121,9 @@ func JSONToFirestoreValue(value any, client *firestore.Client) (any, error) {
return nil, fmt.Errorf("invalid array value format")
case "mapValue":
// Convert map
if mapMap, ok := val.(map[string]any); ok {
if fields, ok := mapMap["fields"].(map[string]any); ok {
result := make(map[string]any)
if mapMap, ok := val.(map[string]interface{}); ok {
if fields, ok := mapMap["fields"].(map[string]interface{}); ok {
result := make(map[string]interface{})
for k, v := range fields {
converted, err := JSONToFirestoreValue(v, client)
if err != nil {
@@ -160,8 +160,8 @@ func JSONToFirestoreValue(value any, client *firestore.Client) (any, error) {
}
// convertPlainMap converts a plain map to Firestore format
func convertPlainMap(m map[string]any, client *firestore.Client) (map[string]any, error) {
result := make(map[string]any)
func convertPlainMap(m map[string]interface{}, client *firestore.Client) (map[string]interface{}, error) {
result := make(map[string]interface{})
for k, v := range m {
converted, err := JSONToFirestoreValue(v, client)
if err != nil {
@@ -172,6 +172,42 @@ func convertPlainMap(m map[string]any, client *firestore.Client) (map[string]any
return result, nil
}
// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
// This removes type information and returns plain values
func FirestoreValueToJSON(value interface{}) interface{} {
if value == nil {
return nil
}
switch v := value.(type) {
case time.Time:
return v.Format(time.RFC3339Nano)
case *latlng.LatLng:
return map[string]interface{}{
"latitude": v.Latitude,
"longitude": v.Longitude,
}
case []byte:
return base64.StdEncoding.EncodeToString(v)
case []interface{}:
result := make([]interface{}, len(v))
for i, item := range v {
result[i] = FirestoreValueToJSON(item)
}
return result
case map[string]interface{}:
result := make(map[string]interface{})
for k, val := range v {
result[k] = FirestoreValueToJSON(val)
}
return result
case *firestore.DocumentRef:
return v.Path
default:
return value
}
}
// isValidDocumentPath checks if a string is a valid Firestore document path
// Valid paths have an even number of segments (collection/doc/collection/doc...)
func isValidDocumentPath(path string) bool {

View File

@@ -312,6 +312,40 @@ func TestJSONToFirestoreValue_IntegerFromString(t *testing.T) {
}
}
func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
// Test round-trip conversion
original := map[string]interface{}{
"name": "Test",
"count": int64(42),
"price": 19.99,
"active": true,
"tags": []interface{}{"tag1", "tag2"},
"metadata": map[string]interface{}{
"created": time.Now(),
},
"nullField": nil,
}
// Convert to JSON representation
jsonRepresentation := FirestoreValueToJSON(original)
// Verify types are simplified
jsonMap, ok := jsonRepresentation.(map[string]interface{})
if !ok {
t.Fatalf("Expected map, got %T", jsonRepresentation)
}
// Time should be converted to string
metadata, ok := jsonMap["metadata"].(map[string]interface{})
if !ok {
t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
}
_, ok = metadata["created"].(string)
if !ok {
t.Errorf("created should be a string, got %T", metadata["created"])
}
}
func TestJSONToFirestoreValue_InvalidFormats(t *testing.T) {
tests := []struct {
name string

View File

@@ -16,7 +16,9 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"slices"
@@ -52,7 +54,7 @@ type compatibleSource interface {
HttpDefaultHeaders() map[string]string
HttpBaseURL() string
HttpQueryParams() map[string]string
RunRequest(*http.Request) (any, error)
Client() *http.Client
}
type Config struct {
@@ -257,7 +259,29 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
for k, v := range allHeaders {
req.Header.Set(k, v)
}
return source.RunRequest(req)
// Make request and fetch response
resp, err := source.Client().Do(req)
if err != nil {
return nil, fmt.Errorf("error making HTTP request: %s", err)
}
defer resp.Body.Close()
var body []byte
body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
}
var data any
if err = json.Unmarshal(body, &data); err != nil {
// if unable to unmarshal data, return result as string.
return string(body), nil
}
return data, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -1,10 +1,10 @@
// Copyright 2026 Google LLC
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package serverlessspark
package common
import (
"fmt"
@@ -23,13 +23,13 @@ import (
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
)
var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)
const (
logTimeBufferBefore = 1 * time.Minute
logTimeBufferAfter = 10 * time.Minute
)
var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)
// Extract BatchDetails extracts the project ID, location, and batch ID from a fully qualified batch name.
func ExtractBatchDetails(batchName string) (projectID, location, batchID string, err error) {
matches := batchFullNameRegex.FindStringSubmatch(batchName)
@@ -39,6 +39,26 @@ func ExtractBatchDetails(batchName string) (projectID, location, batchID string,
return matches[1], matches[2], matches[3], nil
}
// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
if err != nil {
return "", err
}
return BatchConsoleURL(projectID, location, batchID), nil
}
// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
if err != nil {
return "", err
}
createTime := batchPb.GetCreateTime().AsTime()
stateTime := batchPb.GetStateTime().AsTime()
return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
}
// BatchConsoleURL builds a URL to the Google Cloud Console linking to the batch summary page.
func BatchConsoleURL(projectID, location, batchID string) string {
return fmt.Sprintf("https://console.cloud.google.com/dataproc/batches/%s/%s/summary?project=%s", location, batchID, projectID)
@@ -69,23 +89,3 @@ resource.labels.batch_id="%s"`
return "https://console.cloud.google.com/logs/viewer?" + v.Encode()
}
// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
if err != nil {
return "", err
}
return BatchConsoleURL(projectID, location, batchID), nil
}
// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
if err != nil {
return "", err
}
createTime := batchPb.GetCreateTime().AsTime()
stateTime := batchPb.GetStateTime().AsTime()
return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
}

View File

@@ -1,10 +1,10 @@
// Copyright 2026 Google LLC
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package serverlessspark_test
package common
import (
"testing"
"time"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestExtractBatchDetails_Success(t *testing.T) {
batchName := "projects/my-project/locations/us-central1/batches/my-batch"
projectID, location, batchID, err := serverlessspark.ExtractBatchDetails(batchName)
projectID, location, batchID, err := ExtractBatchDetails(batchName)
if err != nil {
t.Errorf("ExtractBatchDetails() error = %v, want no error", err)
return
@@ -46,7 +45,7 @@ func TestExtractBatchDetails_Success(t *testing.T) {
func TestExtractBatchDetails_Failure(t *testing.T) {
batchName := "invalid-name"
_, _, _, err := serverlessspark.ExtractBatchDetails(batchName)
_, _, _, err := ExtractBatchDetails(batchName)
wantErr := "failed to parse batch name: invalid-name"
if err == nil || err.Error() != wantErr {
t.Errorf("ExtractBatchDetails() error = %v, want %v", err, wantErr)
@@ -54,7 +53,7 @@ func TestExtractBatchDetails_Failure(t *testing.T) {
}
func TestBatchConsoleURL(t *testing.T) {
got := serverlessspark.BatchConsoleURL("my-project", "us-central1", "my-batch")
got := BatchConsoleURL("my-project", "us-central1", "my-batch")
want := "https://console.cloud.google.com/dataproc/batches/us-central1/my-batch/summary?project=my-project"
if got != want {
t.Errorf("BatchConsoleURL() = %v, want %v", got, want)
@@ -64,7 +63,7 @@ func TestBatchConsoleURL(t *testing.T) {
func TestBatchLogsURL(t *testing.T) {
startTime := time.Date(2025, 10, 1, 5, 0, 0, 0, time.UTC)
endTime := time.Date(2025, 10, 1, 6, 0, 0, 0, time.UTC)
got := serverlessspark.BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
got := BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
want := "https://console.cloud.google.com/logs/viewer?advancedFilter=" +
"resource.type%3D%22cloud_dataproc_batch%22" +
"%0Aresource.labels.project_id%3D%22my-project%22" +
@@ -83,7 +82,7 @@ func TestBatchConsoleURLFromProto(t *testing.T) {
batchPb := &dataprocpb.Batch{
Name: "projects/my-project/locations/us-central1/batches/my-batch",
}
got, err := serverlessspark.BatchConsoleURLFromProto(batchPb)
got, err := BatchConsoleURLFromProto(batchPb)
if err != nil {
t.Fatalf("BatchConsoleURLFromProto() error = %v", err)
}
@@ -101,7 +100,7 @@ func TestBatchLogsURLFromProto(t *testing.T) {
CreateTime: timestamppb.New(createTime),
StateTime: timestamppb.New(stateTime),
}
got, err := serverlessspark.BatchLogsURLFromProto(batchPb)
got, err := BatchLogsURLFromProto(batchPb)
if err != nil {
t.Fatalf("BatchLogsURLFromProto() error = %v", err)
}

View File

@@ -19,6 +19,7 @@ import (
"encoding/json"
"fmt"
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/goccy/go-yaml"
"google.golang.org/protobuf/encoding/protojson"
@@ -35,7 +36,9 @@ func unmarshalProto(data any, m proto.Message) error {
}
type compatibleSource interface {
CreateBatch(context.Context, *dataprocpb.Batch) (map[string]any, error)
GetBatchControllerClient() *dataproc.BatchControllerClient
GetProject() string
GetLocation() string
}
// Config is a common config that can be used with any type of create batch tool. However, each tool

View File

@@ -16,19 +16,23 @@ package createbatch
import (
"context"
"encoding/json"
"fmt"
"time"
dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
type BatchBuilder interface {
Parameters() parameters.Parameters
BuildBatch(parameters.ParamValues) (*dataprocpb.Batch, error)
BuildBatch(params parameters.ParamValues) (*dataprocpb.Batch, error)
}
func NewTool(cfg Config, originalCfg tools.ToolConfig, srcs map[string]sources.Source, builder BatchBuilder) (*Tool, error) {
@@ -70,6 +74,7 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par
if err != nil {
return nil, err
}
client := source.GetBatchControllerClient()
batch, err := t.Builder.BuildBatch(params)
if err != nil {
@@ -92,7 +97,46 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par
}
batch.RuntimeConfig.Version = version
}
return source.CreateBatch(ctx, batch)
req := &dataprocpb.CreateBatchRequest{
Parent: fmt.Sprintf("projects/%s/locations/%s", source.GetProject(), source.GetLocation()),
Batch: batch,
}
op, err := client.CreateBatch(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to create batch: %w", err)
}
meta, err := op.Metadata()
if err != nil {
return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
}
jsonBytes, err := protojson.Marshal(meta)
if err != nil {
return nil, fmt.Errorf("failed to marshal create batch op metadata to JSON: %w", err)
}
var result map[string]any
if err := json.Unmarshal(jsonBytes, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal create batch op metadata JSON: %w", err)
}
projectID, location, batchID, err := common.ExtractBatchDetails(meta.GetBatch())
if err != nil {
return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
}
consoleUrl := common.BatchConsoleURL(projectID, location, batchID)
logsUrl := common.BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})
wrappedResult := map[string]any{
"opMetadata": meta,
"consoleUrl": consoleUrl,
"logsUrl": logsUrl,
}
return wrappedResult, nil
}
func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -19,7 +19,8 @@ import (
"fmt"
"strings"
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
longrunning "cloud.google.com/go/longrunning/autogen"
"cloud.google.com/go/longrunning/autogen/longrunningpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -44,8 +45,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
GetBatchControllerClient() *dataproc.BatchControllerClient
CancelOperation(context.Context, string) (any, error)
GetOperationsClient(context.Context) (*longrunning.OperationsClient, error)
GetProject() string
GetLocation() string
}
type Config struct {
@@ -104,15 +106,32 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par
if err != nil {
return nil, err
}
client, err := source.GetOperationsClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get operations client: %w", err)
}
paramMap := params.AsMap()
operation, ok := paramMap["operation"].(string)
if !ok {
return nil, fmt.Errorf("missing required parameter: operation")
}
if strings.Contains(operation, "/") {
return nil, fmt.Errorf("operation must be a short operation name without '/': %s", operation)
}
return source.CancelOperation(ctx, operation)
req := &longrunningpb.CancelOperationRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", source.GetProject(), source.GetLocation(), operation),
}
err = client.CancelOperation(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to cancel operation: %w", err)
}
return fmt.Sprintf("Cancelled [%s].", operation), nil
}
func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,15 +16,19 @@ package serverlesssparkgetbatch
import (
"context"
"encoding/json"
"fmt"
"strings"
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/protobuf/encoding/protojson"
)
const kind = "serverless-spark-get-batch"
@@ -45,7 +49,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetBatchControllerClient() *dataproc.BatchControllerClient
GetBatch(context.Context, string) (map[string]any, error)
GetProject() string
GetLocation() string
}
type Config struct {
@@ -104,15 +109,54 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
client := source.GetBatchControllerClient()
paramMap := params.AsMap()
name, ok := paramMap["name"].(string)
if !ok {
return nil, fmt.Errorf("missing required parameter: name")
}
if strings.Contains(name, "/") {
return nil, fmt.Errorf("name must be a short batch name without '/': %s", name)
}
return source.GetBatch(ctx, name)
req := &dataprocpb.GetBatchRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", source.GetProject(), source.GetLocation(), name),
}
batchPb, err := client.GetBatch(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to get batch: %w", err)
}
jsonBytes, err := protojson.Marshal(batchPb)
if err != nil {
return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
}
var result map[string]any
if err := json.Unmarshal(jsonBytes, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
}
consoleUrl, err := common.BatchConsoleURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating console url: %v", err)
}
logsUrl, err := common.BatchLogsURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating logs url: %v", err)
}
wrappedResult := map[string]any{
"consoleUrl": consoleUrl,
"logsUrl": logsUrl,
"batch": result,
}
return wrappedResult, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.Parameters, data, claims)

View File

@@ -17,13 +17,17 @@ package serverlesssparklistbatches
import (
"context"
"fmt"
"time"
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/iterator"
)
const kind = "serverless-spark-list-batches"
@@ -44,7 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetBatchControllerClient() *dataproc.BatchControllerClient
ListBatches(context.Context, *int, string, string) (any, error)
GetProject() string
GetLocation() string
}
type Config struct {
@@ -99,24 +104,95 @@ type Tool struct {
Parameters parameters.Parameters
}
// ListBatchesResponse is the response from the list batches API.
type ListBatchesResponse struct {
Batches []Batch `json:"batches"`
NextPageToken string `json:"nextPageToken"`
}
// Batch represents a single batch job.
type Batch struct {
Name string `json:"name"`
UUID string `json:"uuid"`
State string `json:"state"`
Creator string `json:"creator"`
CreateTime string `json:"createTime"`
Operation string `json:"operation"`
ConsoleURL string `json:"consoleUrl"`
LogsURL string `json:"logsUrl"`
}
// Invoke executes the tool's operation.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramMap := params.AsMap()
var pageSize *int
if ps, ok := paramMap["pageSize"]; ok && ps != nil {
pageSizeV := ps.(int)
if pageSizeV <= 0 {
return nil, fmt.Errorf("pageSize must be positive: %d", pageSizeV)
}
pageSize = &pageSizeV
client := source.GetBatchControllerClient()
parent := fmt.Sprintf("projects/%s/locations/%s", source.GetProject(), source.GetLocation())
req := &dataprocpb.ListBatchesRequest{
Parent: parent,
OrderBy: "create_time desc",
}
pt, _ := paramMap["pageToken"].(string)
filter, _ := paramMap["filter"].(string)
return source.ListBatches(ctx, pageSize, pt, filter)
paramMap := params.AsMap()
if ps, ok := paramMap["pageSize"]; ok && ps != nil {
req.PageSize = int32(ps.(int))
if (req.PageSize) <= 0 {
return nil, fmt.Errorf("pageSize must be positive: %d", req.PageSize)
}
}
if pt, ok := paramMap["pageToken"]; ok && pt != nil {
req.PageToken = pt.(string)
}
if filter, ok := paramMap["filter"]; ok && filter != nil {
req.Filter = filter.(string)
}
it := client.ListBatches(ctx, req)
pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)
var batchPbs []*dataprocpb.Batch
nextPageToken, err := pager.NextPage(&batchPbs)
if err != nil {
return nil, fmt.Errorf("failed to list batches: %w", err)
}
batches, err := ToBatches(batchPbs)
if err != nil {
return nil, err
}
return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
}
// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
batches := make([]Batch, 0, len(batchPbs))
for _, batchPb := range batchPbs {
consoleUrl, err := common.BatchConsoleURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating console url: %v", err)
}
logsUrl, err := common.BatchLogsURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating logs url: %v", err)
}
batch := Batch{
Name: batchPb.Name,
UUID: batchPb.Uuid,
State: batchPb.State.Enum().String(),
Creator: batchPb.Creator,
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
Operation: batchPb.Operation,
ConsoleURL: consoleUrl,
LogsURL: logsUrl,
}
batches = append(batches, batch)
}
return batches, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -214,12 +214,15 @@ func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]a
PostgresListDatabaseStatsToolKind = "postgres-list-database-stats"
PostgresListRolesToolKind = "postgres-list-roles"
PostgresListStoredProcedureToolKind = "postgres-list-stored-procedure"
)
tools, ok := config["tools"].(map[string]any)
if !ok {
t.Fatalf("unable to get tools from config")
}
tools["list_tables"] = map[string]any{
"kind": PostgresListTablesToolKind,
"source": "my-instance",
@@ -943,6 +946,8 @@ func TestCloudSQLMySQL_IPTypeParsingFromYAML(t *testing.T) {
// Finds and drops all tables in a postgres database.
func CleanupPostgresTables(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
// fmt.println("")
t.Logf("in cleanupPostgrestTables");
query := `
SELECT table_name FROM information_schema.tables
WHERE table_schema = 'public' AND table_type = 'BASE TABLE';`
@@ -964,14 +969,31 @@ func CleanupPostgresTables(t *testing.T, ctx context.Context, pool *pgxpool.Pool
}
if len(tablesToDrop) == 0 {
t.Logf("No tables to drop in 'public' schema")
return
}
t.Logf("Tables to drop in 'public' schema: %s", strings.Join(tablesToDrop, ", "))
dropQuery := fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE;", strings.Join(tablesToDrop, ", "))
if _, err := pool.Exec(ctx, dropQuery); err != nil {
t.Fatalf("Failed to drop all tables in 'public' schema: %v", err)
}
t.Logf("Dropped tables in 'public' schema: %s", strings.Join(tablesToDrop, ", "))
// // 1. Drop the entire public schema (this kills tables, views, types, etc.)
// dropSchema := "DROP SCHEMA public CASCADE;"
// // 2. Recreate the empty public schema
// createSchema := "CREATE SCHEMA public;"
// // 3. Grant permissions back (Postgres default)
// grantPublic := "GRANT ALL ON SCHEMA public TO public;"
// _, err := pool.Exec(ctx, dropSchema + createSchema + grantPublic)
// if err != nil {
// t.Fatalf("Failed to nuclear-wipe the public schema: %v", err)
// }
}
// Finds and drops all tables in a mysql database.

View File

@@ -33,8 +33,8 @@ import (
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
"github.com/googleapis/genai-toolbox/tests"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
@@ -676,7 +676,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
filter string
pageSize int
numPages int
want []serverlessspark.Batch
want []serverlesssparklistbatches.Batch
}{
{name: "one page", pageSize: 2, numPages: 1, want: batch2},
{name: "two pages", pageSize: 1, numPages: 2, want: batch2},
@@ -701,7 +701,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var actual []serverlessspark.Batch
var actual []serverlesssparklistbatches.Batch
var pageToken string
for i := 0; i < tc.numPages; i++ {
request := map[string]any{
@@ -733,7 +733,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
t.Fatalf("unable to find result in response body")
}
var listResponse serverlessspark.ListBatchesResponse
var listResponse serverlesssparklistbatches.ListBatchesResponse
if err := json.Unmarshal([]byte(result), &listResponse); err != nil {
t.Fatalf("error unmarshalling result: %s", err)
}
@@ -759,7 +759,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
}
}
func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, filter string, n int, exact bool) []serverlessspark.Batch {
func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, filter string, n int, exact bool) []serverlesssparklistbatches.Batch {
parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation)
req := &dataprocpb.ListBatchesRequest{
Parent: parent,
@@ -783,7 +783,7 @@ func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx co
if !exact && (len(batchPbs) == 0 || len(batchPbs) > n) {
t.Fatalf("expected between 1 and %d batches, got %d", n, len(batchPbs))
}
batches, err := serverlessspark.ToBatches(batchPbs)
batches, err := serverlesssparklistbatches.ToBatches(batchPbs)
if err != nil {
t.Fatalf("failed to convert batches to JSON: %v", err)
}