mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-15 02:18:10 -05:00
Compare commits
15 Commits
v0.25.0
...
config-pre
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c786786d5 | ||
|
|
94e19b62b6 | ||
|
|
8e0fb03483 | ||
|
|
2817dd1e5d | ||
|
|
68a218407e | ||
|
|
d69792d843 | ||
|
|
647b04d3a7 | ||
|
|
030df9766f | ||
|
|
5dbf207162 | ||
|
|
9c3720e31d | ||
|
|
3cd3c39d66 | ||
|
|
0691a6f715 | ||
|
|
467b96a23b | ||
|
|
4abf0c39e7 | ||
|
|
dd7b9de623 |
@@ -59,6 +59,13 @@ You can manually trigger the bot by commenting on your Pull Request:
|
||||
* `/gemini summary`: Posts a summary of the changes in the pull request.
|
||||
* `/gemini help`: Overview of the available commands
|
||||
|
||||
## Guidelines for Pull Requests
|
||||
|
||||
1. Please keep your PR small for more thorough review and easier updates. In case of regression, it also allows us to roll back a single feature instead of multiple ones.
|
||||
1. For non-trivial changes, consider opening an issue and discussing it with the code owners first.
|
||||
1. Provide a good PR description as a record of what change is being made and why it was made. Link to a GitHub issue if it exists.
|
||||
1. Make sure your code is thoroughly tested with unit tests and integration tests. Remember to clean up the test instances properly in your code to avoid memory leaks.
|
||||
|
||||
## Adding a New Database Source or Tool
|
||||
|
||||
Please create an
|
||||
@@ -85,11 +92,11 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
|
||||
`newdb.go`. Create a `Config` struct to include all the necessary parameters
|
||||
for connecting to the database (e.g., host, port, username, password, database
|
||||
name) and a `Source` struct to store necessary parameters for tools (e.g.,
|
||||
Name, Kind, connection object, additional config).
|
||||
Name, Type, connection object, additional config).
|
||||
* **Implement the
|
||||
[`SourceConfig`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/sources/sources.go#L57)
|
||||
interface**. This interface requires two methods:
|
||||
* `SourceConfigKind() string`: Returns a unique string identifier for your
|
||||
* `SourceConfigType() string`: Returns a unique string identifier for your
|
||||
data source (e.g., `"newdb"`).
|
||||
* `Initialize(ctx context.Context, tracer trace.Tracer) (Source, error)`:
|
||||
Creates a new instance of your data source and establishes a connection to
|
||||
@@ -97,7 +104,7 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
|
||||
* **Implement the
|
||||
[`Source`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/sources/sources.go#L63)
|
||||
interface**. This interface requires one method:
|
||||
* `SourceKind() string`: Returns the same string identifier as `SourceConfigKind()`.
|
||||
* `SourceType() string`: Returns the same string identifier as `SourceConfigType()`.
|
||||
* **Implement `init()`** to register the new Source.
|
||||
* **Implement Unit Tests** in a file named `newdb_test.go`.
|
||||
|
||||
@@ -110,6 +117,8 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
|
||||
We recommend looking at an [example tool
|
||||
implementation](https://github.com/googleapis/genai-toolbox/tree/main/internal/tools/postgres/postgressql).
|
||||
|
||||
Remember to keep your PRs small. For example, if you are contributing a new Source, only include one or two core Tools within the same PR, the rest of the Tools can come in subsequent PRs.
|
||||
|
||||
* **Create a new directory** under `internal/tools` for your tool type (e.g., `internal/tools/newdb/newdbtool`).
|
||||
* **Define a configuration struct** for your tool in a file named `newdbtool.go`.
|
||||
Create a `Config` struct and a `Tool` struct to store necessary parameters for
|
||||
@@ -117,7 +126,7 @@ tools.
|
||||
* **Implement the
|
||||
[`ToolConfig`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/tools/tools.go#L61)
|
||||
interface**. This interface requires one method:
|
||||
* `ToolConfigKind() string`: Returns a unique string identifier for your tool
|
||||
* `ToolConfigType() string`: Returns a unique string identifier for your tool
|
||||
(e.g., `"newdb-tool"`).
|
||||
* `Initialize(sources map[string]Source) (Tool, error)`: Creates a new
|
||||
instance of your tool and validates that it can connect to the specified
|
||||
@@ -163,6 +172,8 @@ tools.
|
||||
parameters][temp-param-doc]. Only run this test if template
|
||||
parameters apply to your tool.
|
||||
|
||||
* **Add additional tests** for the tools that are not covered by the predefined tests. Every tool must be tested!
|
||||
|
||||
* **Add the new database to the integration test workflow** in
|
||||
[integration.cloudbuild.yaml](.ci/integration.cloudbuild.yaml).
|
||||
|
||||
@@ -232,7 +243,7 @@ resources.
|
||||
| style | Update src code, with only formatting and whitespace updates (e.g. code formatter or linter changes). |
|
||||
|
||||
Pull requests should always add scope whenever possible. The scope is
|
||||
formatted as `<scope-type>/<scope-kind>` (e.g., `sources/postgres`, or
|
||||
formatted as `<scope-resource>/<scope-type>` (e.g., `sources/postgres`, or
|
||||
`tools/mssql-sql`).
|
||||
|
||||
Ideally, **each PR covers only one scope**, if this is
|
||||
@@ -244,4 +255,4 @@ resources.
|
||||
* **PR Description:** PR description should **always** be included. It should
|
||||
include a concise description of the changes, it's impact, along with a
|
||||
summary of the solution. If the PR is related to a specific issue, the issue
|
||||
number should be mentioned in the PR description (e.g. `Fixes #1`).
|
||||
number should be mentioned in the PR description (e.g. `Fixes #1`).
|
||||
|
||||
17
DEVELOPER.md
17
DEVELOPER.md
@@ -379,6 +379,23 @@ to approve PRs for main. TeamSync is used to create this team from the MDB
|
||||
Group `toolbox-contributors`. Googlers who are developing for MCP-Toolbox
|
||||
but aren't part of the core team should join this group.
|
||||
|
||||
### Issue/PR Triage and SLO
|
||||
After an issue is created, maintainers will assign the following labels:
|
||||
* `Priority` (defaulted to P0)
|
||||
* `Type` (if applicable)
|
||||
* `Product` (if applicable)
|
||||
|
||||
All incoming issues and PRs will follow the following SLO:
|
||||
| Type | Priority | Objective |
|
||||
|-----------------|----------|------------------------------------------------------------------------|
|
||||
| Feature Request | P0 | Must respond within **5 days** |
|
||||
| Process | P0 | Must respond within **5 days** |
|
||||
| Bugs | P0 | Must respond within **5 days**, and resolve/closure within **14 days** |
|
||||
| Bugs | P1 | Must respond within **7 days**, and resolve/closure within **90 days** |
|
||||
| Bugs | P2 | Must respond within **30 days**
|
||||
|
||||
_Types that are not listed in the table do not adhere to any SLO._
|
||||
|
||||
### Releasing
|
||||
|
||||
Toolbox has two types of releases: versioned and continuous. It uses Google
|
||||
|
||||
@@ -272,7 +272,7 @@ To run Toolbox from binary:
|
||||
To run the server after pulling the [container image](#installing-the-server):
|
||||
|
||||
```sh
|
||||
export VERSION=0.11.0 # Use the version you pulled
|
||||
export VERSION=0.24.0 # Use the version you pulled
|
||||
docker run -p 5000:5000 \
|
||||
-v $(pwd)/tools.yaml:/app/tools.yaml \
|
||||
us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION \
|
||||
@@ -954,7 +954,7 @@ For more details on configuring different types of sources, see the
|
||||
### Tools
|
||||
|
||||
The `tools` section of a `tools.yaml` define the actions an agent can take: what
|
||||
kind of tool it is, which source(s) it affects, what parameters it uses, etc.
|
||||
type of tool it is, which source(s) it affects, what parameters it uses, etc.
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
|
||||
99
cmd/root.go
99
cmd/root.go
@@ -15,6 +15,7 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
@@ -92,6 +93,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances"
|
||||
@@ -424,6 +426,103 @@ func parseEnv(input string) (string, error) {
|
||||
return output, err
|
||||
}
|
||||
|
||||
func convertToolsFile(ctx context.Context, raw []byte) ([]byte, error) {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: add embeddingmodels when available
|
||||
keysToCheck := []string{"sources", "authServices", "authSources", "tools", "prompts", "toolsets"}
|
||||
var input yaml.MapSlice
|
||||
decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap())
|
||||
if err := decoder.Decode(&input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert raw MapSlice to a helper map for quick lookup
|
||||
// while keeping the values as MapSlices to preserve internal order
|
||||
lookup := make(map[string]yaml.MapSlice)
|
||||
for _, item := range input {
|
||||
key := item.Key.(string)
|
||||
if slice, ok := item.Value.(yaml.MapSlice); ok {
|
||||
// convert authSources to authServices
|
||||
if key == "authSources" {
|
||||
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
|
||||
key = "authServices"
|
||||
}
|
||||
// works even if lookup[key] is nil
|
||||
lookup[key] = append(lookup[key], slice...)
|
||||
} else {
|
||||
// toolsfile is already v2
|
||||
if key == "kind" {
|
||||
return raw, nil
|
||||
}
|
||||
return nil, fmt.Errorf("'%s' is not a map", key)
|
||||
}
|
||||
}
|
||||
// convert to tools file v2
|
||||
var buf bytes.Buffer
|
||||
encoder := yaml.NewEncoder(&buf)
|
||||
for _, kind := range keysToCheck {
|
||||
data, exists := lookup[kind]
|
||||
if !exists {
|
||||
// if this is skipped for all keys, the tools file is in v2
|
||||
continue
|
||||
}
|
||||
// Transform each entry
|
||||
for _, entry := range data {
|
||||
entryName := entry.Key.(string)
|
||||
entryBody := ProcessValue(entry.Value, kind == "toolsets")
|
||||
|
||||
transformed := yaml.MapSlice{
|
||||
{Key: "kind", Value: kind},
|
||||
{Key: "name", Value: entryName},
|
||||
}
|
||||
|
||||
// Merge the transformed body into our result
|
||||
if bodySlice, ok := entryBody.(yaml.MapSlice); ok {
|
||||
transformed = append(transformed, bodySlice...)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unable to convert entryBody to MapSlice")
|
||||
}
|
||||
|
||||
if err := encoder.Encode(transformed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type'
|
||||
func ProcessValue(v any, isToolset bool) any {
|
||||
switch val := v.(type) {
|
||||
case yaml.MapSlice:
|
||||
for i := range val {
|
||||
// Perform renaming
|
||||
if val[i].Key == "kind" {
|
||||
val[i].Key = "type"
|
||||
}
|
||||
// Recursive call for nested values (e.g., nested objects or lists)
|
||||
val[i].Value = ProcessValue(val[i].Value, false)
|
||||
}
|
||||
return val
|
||||
case []any:
|
||||
// Process lists: If it's a toolset top-level list, wrap it.
|
||||
if isToolset {
|
||||
return yaml.MapSlice{{Key: "tools", Value: val}}
|
||||
}
|
||||
// Otherwise, recurse into list items (to catch nested objects)
|
||||
for i := range val {
|
||||
val[i] = ProcessValue(val[i], false)
|
||||
}
|
||||
return val
|
||||
default:
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
// parseToolsFile parses the provided yaml into appropriate configs.
|
||||
func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
|
||||
var toolsFile ToolsFile
|
||||
|
||||
273
cmd/root_test.go
273
cmd/root_test.go
@@ -23,12 +23,14 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
||||
@@ -494,6 +496,235 @@ func TestDefaultLogLevel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToolsFile(t *testing.T) {
|
||||
ctx, cancelCtx := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancelCtx()
|
||||
pr, pw := io.Pipe()
|
||||
defer pw.Close()
|
||||
defer pr.Close()
|
||||
|
||||
logger, err := log.NewStdLogger(pw, pw, "DEBUG")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup logger %s", err)
|
||||
}
|
||||
ctx = util.WithLogger(ctx, logger)
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want string
|
||||
isErr bool
|
||||
errStr string
|
||||
}{
|
||||
{
|
||||
desc: "basic convert",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
toolsets:
|
||||
example_toolset:
|
||||
- example_tool`,
|
||||
want: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
},
|
||||
{
|
||||
desc: "rearrange resource order",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
toolsets:
|
||||
example_toolset:
|
||||
- example_tool`,
|
||||
want: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
},
|
||||
{
|
||||
desc: "no convertion needed",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
want: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
},
|
||||
{
|
||||
desc: "invalid source",
|
||||
in: `sources: invalid`,
|
||||
isErr: true,
|
||||
errStr: "'sources' is not a map",
|
||||
},
|
||||
{
|
||||
desc: "invalid toolset",
|
||||
in: `toolsets: invalid`,
|
||||
isErr: true,
|
||||
errStr: "'toolsets' is not a map",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
output, err := convertToolsFile(ctx, []byte(tc.in))
|
||||
if tc.isErr {
|
||||
if err == nil {
|
||||
t.Fatalf("missing error: %s", tc.errStr)
|
||||
}
|
||||
if err.Error() != tc.errStr {
|
||||
t.Fatalf("invalid error string: got %s, want %s", err, tc.errStr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
// ensures that the order is correct
|
||||
var doc1, doc2 yaml.MapSlice
|
||||
if err := yaml.Unmarshal(output, &doc1); err != nil {
|
||||
t.Fatalf("unable to unmarshal output: %s", string(output))
|
||||
}
|
||||
if err := yaml.Unmarshal([]byte(tc.want), &doc2); err != nil {
|
||||
t.Fatalf("unable to unmarshal output: %s", tc.want)
|
||||
}
|
||||
if !reflect.DeepEqual(doc1, doc2) {
|
||||
t.Fatalf("incorrect output: got %s, want %s", doc1, doc2)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolFile(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
@@ -535,7 +766,7 @@ func TestParseToolFile(t *testing.T) {
|
||||
Sources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: cloudsqlpgsrc.SourceKind,
|
||||
Type: cloudsqlpgsrc.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -548,7 +779,7 @@ func TestParseToolFile(t *testing.T) {
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "postgres-sql",
|
||||
Type: "postgres-sql",
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
@@ -689,7 +920,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
Sources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: cloudsqlpgsrc.SourceKind,
|
||||
Type: cloudsqlpgsrc.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -702,19 +933,19 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
AuthServices: server.AuthServiceConfigs{
|
||||
"my-google-service": google.Config{
|
||||
Name: "my-google-service",
|
||||
Kind: google.AuthServiceKind,
|
||||
Type: google.AuthServiceType,
|
||||
ClientID: "my-client-id",
|
||||
},
|
||||
"other-google-service": google.Config{
|
||||
Name: "other-google-service",
|
||||
Kind: google.AuthServiceKind,
|
||||
Type: google.AuthServiceType,
|
||||
ClientID: "other-client-id",
|
||||
},
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "postgres-sql",
|
||||
Type: "postgres-sql",
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
@@ -789,7 +1020,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
Sources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: cloudsqlpgsrc.SourceKind,
|
||||
Type: cloudsqlpgsrc.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -802,19 +1033,19 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
AuthSources: server.AuthServiceConfigs{
|
||||
"my-google-service": google.Config{
|
||||
Name: "my-google-service",
|
||||
Kind: google.AuthServiceKind,
|
||||
Type: google.AuthServiceType,
|
||||
ClientID: "my-client-id",
|
||||
},
|
||||
"other-google-service": google.Config{
|
||||
Name: "other-google-service",
|
||||
Kind: google.AuthServiceKind,
|
||||
Type: google.AuthServiceType,
|
||||
ClientID: "other-client-id",
|
||||
},
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "postgres-sql",
|
||||
Type: "postgres-sql",
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
@@ -891,7 +1122,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
Sources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: cloudsqlpgsrc.SourceKind,
|
||||
Type: cloudsqlpgsrc.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -904,19 +1135,19 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
AuthServices: server.AuthServiceConfigs{
|
||||
"my-google-service": google.Config{
|
||||
Name: "my-google-service",
|
||||
Kind: google.AuthServiceKind,
|
||||
Type: google.AuthServiceType,
|
||||
ClientID: "my-client-id",
|
||||
},
|
||||
"other-google-service": google.Config{
|
||||
Name: "other-google-service",
|
||||
Kind: google.AuthServiceKind,
|
||||
Type: google.AuthServiceType,
|
||||
ClientID: "other-client-id",
|
||||
},
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "postgres-sql",
|
||||
Type: "postgres-sql",
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
@@ -1062,7 +1293,7 @@ func TestEnvVarReplacement(t *testing.T) {
|
||||
Sources: server.SourceConfigs{
|
||||
"my-http-instance": httpsrc.Config{
|
||||
Name: "my-http-instance",
|
||||
Kind: httpsrc.SourceKind,
|
||||
Type: httpsrc.SourceType,
|
||||
BaseURL: "http://test_server/",
|
||||
Timeout: "10s",
|
||||
DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"},
|
||||
@@ -1072,19 +1303,19 @@ func TestEnvVarReplacement(t *testing.T) {
|
||||
AuthServices: server.AuthServiceConfigs{
|
||||
"my-google-service": google.Config{
|
||||
Name: "my-google-service",
|
||||
Kind: google.AuthServiceKind,
|
||||
Type: google.AuthServiceType,
|
||||
ClientID: "ACTUAL_CLIENT_ID",
|
||||
},
|
||||
"other-google-service": google.Config{
|
||||
Name: "other-google-service",
|
||||
Kind: google.AuthServiceKind,
|
||||
Type: google.AuthServiceType,
|
||||
ClientID: "ACTUAL_CLIENT_ID_2",
|
||||
},
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": http.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "http",
|
||||
Type: "http",
|
||||
Source: "my-instance",
|
||||
Method: "GET",
|
||||
Path: "search?name=alice&pet=cat",
|
||||
@@ -1493,7 +1724,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"cloud_sql_postgres_admin_tools": tools.ToolsetConfig{
|
||||
Name: "cloud_sql_postgres_admin_tools",
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance"},
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1503,7 +1734,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"cloud_sql_mysql_admin_tools": tools.ToolsetConfig{
|
||||
Name: "cloud_sql_mysql_admin_tools",
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1513,7 +1744,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"cloud_sql_mssql_admin_tools": tools.ToolsetConfig{
|
||||
Name: "cloud_sql_mssql_admin_tools",
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
0
cmd/test.db
Normal file
0
cmd/test.db
Normal file
@@ -48,6 +48,7 @@ instance, database and users:
|
||||
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* `roles/cloudsql.admin`: Provides full control over all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
* `create_instance`
|
||||
@@ -299,6 +300,7 @@ instances and interacting with your database:
|
||||
* **create_user**: Creates a new user in a Cloud SQL instance.
|
||||
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
|
||||
* **clone_instance**: Creates a clone of an existing Cloud SQL for SQL Server instance.
|
||||
* **create_backup**: Creates a backup on a Cloud SQL instance.
|
||||
|
||||
{{< notice note >}}
|
||||
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
|
||||
|
||||
@@ -48,6 +48,7 @@ database and users:
|
||||
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* `roles/cloudsql.admin`: Provides full control over all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
* `create_instance`
|
||||
@@ -299,6 +300,7 @@ instances and interacting with your database:
|
||||
* **create_user**: Creates a new user in a Cloud SQL instance.
|
||||
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
|
||||
* **clone_instance**: Creates a clone of an existing Cloud SQL for MySQL instance.
|
||||
* **create_backup**: Creates a backup on a Cloud SQL instance.
|
||||
|
||||
{{< notice note >}}
|
||||
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
|
||||
|
||||
@@ -48,6 +48,7 @@ instance, database and users:
|
||||
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* `roles/cloudsql.admin`: Provides full control over all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
* `create_instance`
|
||||
@@ -299,6 +300,7 @@ instances and interacting with your database:
|
||||
* **create_user**: Creates a new user in a Cloud SQL instance.
|
||||
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
|
||||
* **clone_instance**: Creates a clone of an existing Cloud SQL for PostgreSQL instance.
|
||||
* **create_backup**: Creates a backup on a Cloud SQL instance.
|
||||
|
||||
{{< notice note >}}
|
||||
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
|
||||
|
||||
@@ -187,6 +187,7 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
|
||||
all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
@@ -203,6 +204,7 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
* `create_user`: Creates a new user in a Cloud SQL instance.
|
||||
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
|
||||
* `clone_instance`: Creates a clone for an existing Cloud SQL for MySQL instance.
|
||||
* `create_backup`: Creates a backup on a Cloud SQL instance.
|
||||
|
||||
## Cloud SQL for PostgreSQL
|
||||
|
||||
@@ -275,6 +277,7 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
|
||||
all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
@@ -290,6 +293,7 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
* `create_user`: Creates a new user in a Cloud SQL instance.
|
||||
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
|
||||
* `clone_instance`: Creates a clone for an existing Cloud SQL for PostgreSQL instance.
|
||||
* `create_backup`: Creates a backup on a Cloud SQL instance.
|
||||
|
||||
## Cloud SQL for SQL Server
|
||||
|
||||
@@ -336,6 +340,7 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
|
||||
all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
@@ -351,6 +356,7 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
* `create_user`: Creates a new user in a Cloud SQL instance.
|
||||
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
|
||||
* `clone_instance`: Creates a clone for an existing Cloud SQL for SQL Server instance.
|
||||
* `create_backup`: Creates a backup on a Cloud SQL instance.
|
||||
|
||||
## Dataplex
|
||||
|
||||
|
||||
@@ -134,6 +134,7 @@ sources:
|
||||
# scopes: # Optional: List of OAuth scopes to request.
|
||||
# - "https://www.googleapis.com/auth/bigquery"
|
||||
# - "https://www.googleapis.com/auth/drive.readonly"
|
||||
# maxQueryResultRows: 50 # Optional: Limits the number of rows returned by queries. Defaults to 50.
|
||||
```
|
||||
|
||||
Initialize a BigQuery source that uses the client's access token:
|
||||
@@ -153,6 +154,7 @@ sources:
|
||||
# scopes: # Optional: List of OAuth scopes to request.
|
||||
# - "https://www.googleapis.com/auth/bigquery"
|
||||
# - "https://www.googleapis.com/auth/drive.readonly"
|
||||
# maxQueryResultRows: 50 # Optional: Limits the number of rows returned by queries. Defaults to 50.
|
||||
```
|
||||
|
||||
## Reference
|
||||
@@ -167,3 +169,4 @@ sources:
|
||||
| useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. **Note:** This cannot be used with `writeMode: protected`. |
|
||||
| scopes | []string | false | A list of OAuth 2.0 scopes to use for the credentials. If not provided, default scopes are used. |
|
||||
| impersonateServiceAccount | string | false | Service account email to impersonate when making BigQuery and Dataplex API calls. The authenticated principal must have the `roles/iam.serviceAccountTokenCreator` role on the target service account. [Learn More](https://cloud.google.com/iam/docs/service-account-impersonation) |
|
||||
| maxQueryResultRows | int | false | The maximum number of rows to return from a query. Defaults to 50. |
|
||||
|
||||
@@ -91,8 +91,8 @@ visible to the LLM.
|
||||
https://cloud.google.com/alloydb/docs/parameterized-secure-views-overview
|
||||
|
||||
{{< notice tip >}} Make sure to enable the `parameterized_views` extension
|
||||
before running this tool. You can do so by running this command in the AlloyDB
|
||||
studio:
|
||||
to utilize PSV feature (`nlConfigParameters`) with this tool. You can do so by
|
||||
running this command in the AlloyDB studio:
|
||||
|
||||
```sql
|
||||
CREATE EXTENSION IF NOT EXISTS parameterized_views;
|
||||
|
||||
45
docs/en/resources/tools/cloudsql/cloudsqlcreatebackup.md
Normal file
45
docs/en/resources/tools/cloudsql/cloudsqlcreatebackup.md
Normal file
@@ -0,0 +1,45 @@
|
||||
---
|
||||
title: cloud-sql-create-backup
|
||||
type: docs
|
||||
weight: 10
|
||||
description: "Creates a backup on a Cloud SQL instance."
|
||||
---
|
||||
|
||||
The `cloud-sql-create-backup` tool creates an on-demand backup on a Cloud SQL instance using the Cloud SQL Admin API.
|
||||
|
||||
{{< notice info dd>}}
|
||||
This tool uses a `source` of kind `cloud-sql-admin`.
|
||||
{{< /notice >}}
|
||||
|
||||
## Examples
|
||||
|
||||
Basic backup creation (current state)
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
backup-creation-basic:
|
||||
kind: cloud-sql-create-backup
|
||||
source: cloud-sql-admin-source
|
||||
description: "Creates a backup on the given Cloud SQL instance."
|
||||
```
|
||||
## Reference
|
||||
### Tool Configuration
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| -------------- | :------: | :----------: | ------------------------------------------------------------- |
|
||||
| kind | string | true | Must be "cloud-sql-create-backup". |
|
||||
| source | string | true | The name of the `cloud-sql-admin` source to use. |
|
||||
| description | string | false | A description of the tool. |
|
||||
|
||||
### Tool Inputs
|
||||
|
||||
| **parameter** | **type** | **required** | **description** |
|
||||
| -------------------------- | :------: | :----------: | ------------------------------------------------------------------------------- |
|
||||
| project | string | true | The project ID. |
|
||||
| instance | string | true | The name of the instance to take a backup on. Does not include the project ID. |
|
||||
| location | string | false | (Optional) Location of the backup run. |
|
||||
| backup_description | string | false | (Optional) The description of this backup run. |
|
||||
|
||||
## See Also
|
||||
- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api)
|
||||
- [Toolbox Cloud SQL tools documentation](../cloudsql)
|
||||
- [Cloud SQL Backup API documentation](https://cloud.google.com/sql/docs/mysql/backup-recovery/backups)
|
||||
7
docs/en/samples/neo4j/_index.md
Normal file
7
docs/en/samples/neo4j/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "Neo4j"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
How to get started with Toolbox using Neo4j.
|
||||
---
|
||||
141
docs/en/samples/neo4j/mcp_quickstart.md
Normal file
141
docs/en/samples/neo4j/mcp_quickstart.md
Normal file
@@ -0,0 +1,141 @@
|
||||
---
|
||||
title: "Quickstart (MCP with Neo4j)"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
How to get started running Toolbox with MCP Inspector and Neo4j as the source.
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
[Model Context Protocol](https://modelcontextprotocol.io) is an open protocol that standardizes how applications provide context to LLMs. Check out this page on how to [connect to Toolbox via MCP](../../how-to/connect_via_mcp.md).
|
||||
|
||||
|
||||
## Step 1: Set up your Neo4j Database and Data
|
||||
|
||||
In this section, you'll set up a database and populate it with sample data for a movies-related agent. This guide assumes you have a running Neo4j instance, either locally or in the cloud.
|
||||
|
||||
. **Populate the database with data.**
|
||||
To make this quickstart straightforward, we'll use the built-in Movies dataset available in Neo4j.
|
||||
|
||||
. In your Neo4j Browser, run the following command to create and populate the database:
|
||||
+
|
||||
```cypher
|
||||
:play movies
|
||||
````
|
||||
|
||||
. Follow the instructions to load the data. This will create a graph with `Movie`, `Person`, and `Actor` nodes and their relationships.
|
||||
|
||||
|
||||
## Step 2: Install and configure Toolbox
|
||||
|
||||
In this section, we will install the MCP Toolbox, configure our tools in a `tools.yaml` file, and then run the Toolbox server.
|
||||
|
||||
. **Install the Toolbox binary.**
|
||||
The simplest way to get started is to download the latest binary for your operating system.
|
||||
|
||||
. Download the latest version of Toolbox as a binary:
|
||||
\+
|
||||
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O [https://storage.googleapis.com/genai-toolbox/v0.16.0/$OS/toolbox](https://storage.googleapis.com/genai-toolbox/v0.16.0/$OS/toolbox)
|
||||
```
|
||||
|
||||
+
|
||||
. Make the binary executable:
|
||||
\+
|
||||
|
||||
```bash
|
||||
chmod +x toolbox
|
||||
```
|
||||
|
||||
. **Create the `tools.yaml` file.**
|
||||
This file defines your Neo4j source and the specific tools that will be exposed to your AI agent.
|
||||
\+
|
||||
{{\< notice tip \>}}
|
||||
Authentication for the Neo4j source uses standard username and password fields. For production use, it is highly recommended to use environment variables for sensitive information like passwords.
|
||||
{{\< /notice \>}}
|
||||
\+
|
||||
Write the following into a `tools.yaml` file:
|
||||
\+
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-neo4j-source:
|
||||
kind: neo4j
|
||||
uri: bolt://localhost:7687
|
||||
user: neo4j
|
||||
password: my-password # Replace with your actual password
|
||||
|
||||
tools:
|
||||
search-movies-by-actor:
|
||||
kind: neo4j-cypher
|
||||
source: my-neo4j-source
|
||||
description: "Searches for movies an actor has appeared in based on their name. Useful for questions like 'What movies has Tom Hanks been in?'"
|
||||
parameters:
|
||||
- name: actor_name
|
||||
type: string
|
||||
description: The full name of the actor to search for.
|
||||
statement: |
|
||||
MATCH (p:Person {name: $actor_name}) -[:ACTED_IN]-> (m:Movie)
|
||||
RETURN m.title AS title, m.year AS year, m.genre AS genre
|
||||
|
||||
get-actor-for-movie:
|
||||
kind: neo4j-cypher
|
||||
source: my-neo4j-source
|
||||
description: "Finds the actors who starred in a specific movie. Useful for questions like 'Who acted in Inception?'"
|
||||
parameters:
|
||||
- name: movie_title
|
||||
type: string
|
||||
description: The exact title of the movie.
|
||||
statement: |
|
||||
MATCH (p:Person) -[:ACTED_IN]-> (m:Movie {title: $movie_title})
|
||||
RETURN p.name AS actor
|
||||
```
|
||||
|
||||
. **Start the Toolbox server.**
|
||||
Run the Toolbox server, pointing to the `tools.yaml` file you created earlier.
|
||||
\+
|
||||
|
||||
```bash
|
||||
./toolbox --tools-file "tools.yaml"
|
||||
```
|
||||
|
||||
## Step 3: Connect to MCP Inspector
|
||||
|
||||
. **Run the MCP Inspector:**
|
||||
\+
|
||||
|
||||
```bash
|
||||
npx @modelcontextprotocol/inspector
|
||||
```
|
||||
|
||||
. Type `y` when it asks to install the inspector package.
|
||||
. It should show the following when the MCP Inspector is up and running (please take note of `<YOUR_SESSION_TOKEN>`):
|
||||
\+
|
||||
|
||||
```bash
|
||||
Starting MCP inspector...
|
||||
⚙️ Proxy server listening on localhost:6277
|
||||
🔑 Session token: <YOUR_SESSION_TOKEN>
|
||||
Use this token to authenticate requests or set DANGEROUSLY_OMIT_AUTH=true to disable auth
|
||||
|
||||
🚀 MCP Inspector is up and running at:
|
||||
http://localhost:6274/?MCP_PROXY_AUTH_TOKEN=<YOUR_SESSION_TOKEN>
|
||||
```
|
||||
|
||||
1. Open the above link in your browser.
|
||||
|
||||
1. For `Transport Type`, select `Streamable HTTP`.
|
||||
|
||||
1. For `URL`, type in `http://127.0.0.1:5000/mcp`.
|
||||
|
||||
1. For `Configuration` -\> `Proxy Session Token`, make sure `<YOUR_SESSION_TOKEN>` is present.
|
||||
|
||||
1. Click `Connect`.
|
||||
|
||||
1. Select `List Tools`, you will see a list of tools configured in `tools.yaml`.
|
||||
|
||||
1. Test out your tools here\!
|
||||
|
||||
@@ -21,13 +21,13 @@ import (
|
||||
|
||||
// AuthServiceConfig is the interface for configuring authentication services.
|
||||
type AuthServiceConfig interface {
|
||||
AuthServiceConfigKind() string
|
||||
AuthServiceConfigType() string
|
||||
Initialize() (AuthService, error)
|
||||
}
|
||||
|
||||
// AuthService is the interface for authentication services.
|
||||
type AuthService interface {
|
||||
AuthServiceKind() string
|
||||
AuthServiceType() string
|
||||
GetName() string
|
||||
GetClaimsFromHeader(context.Context, http.Header) (map[string]any, error)
|
||||
ToConfig() AuthServiceConfig
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
"google.golang.org/api/idtoken"
|
||||
)
|
||||
|
||||
const AuthServiceKind string = "google"
|
||||
const AuthServiceType string = "google"
|
||||
|
||||
// validate interface
|
||||
var _ auth.AuthServiceConfig = Config{}
|
||||
@@ -31,13 +31,13 @@ var _ auth.AuthServiceConfig = Config{}
|
||||
// Auth service configuration
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
ClientID string `yaml:"clientId" validate:"required"`
|
||||
}
|
||||
|
||||
// Returns the auth service kind
|
||||
func (cfg Config) AuthServiceConfigKind() string {
|
||||
return AuthServiceKind
|
||||
// Returns the auth service type
|
||||
func (cfg Config) AuthServiceConfigType() string {
|
||||
return AuthServiceType
|
||||
}
|
||||
|
||||
// Initialize a Google auth service
|
||||
@@ -55,9 +55,9 @@ type AuthService struct {
|
||||
Config
|
||||
}
|
||||
|
||||
// Returns the auth service kind
|
||||
func (a AuthService) AuthServiceKind() string {
|
||||
return AuthServiceKind
|
||||
// Returns the auth service type
|
||||
func (a AuthService) AuthServiceType() string {
|
||||
return AuthServiceType
|
||||
}
|
||||
|
||||
func (a AuthService) ToConfig() auth.AuthServiceConfig {
|
||||
|
||||
@@ -22,12 +22,12 @@ import (
|
||||
|
||||
// EmbeddingModelConfig is the interface for configuring embedding models.
|
||||
type EmbeddingModelConfig interface {
|
||||
EmbeddingModelConfigKind() string
|
||||
EmbeddingModelConfigType() string
|
||||
Initialize(context.Context) (EmbeddingModel, error)
|
||||
}
|
||||
|
||||
type EmbeddingModel interface {
|
||||
EmbeddingModelKind() string
|
||||
EmbeddingModelType() string
|
||||
ToConfig() EmbeddingModelConfig
|
||||
EmbedParameters(context.Context, []string) ([][]float32, error)
|
||||
}
|
||||
|
||||
@@ -23,22 +23,22 @@ import (
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
const EmbeddingModelKind string = "gemini"
|
||||
const EmbeddingModelType string = "gemini"
|
||||
|
||||
// validate interface
|
||||
var _ embeddingmodels.EmbeddingModelConfig = Config{}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Model string `yaml:"model" validate:"required"`
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
Dimension int32 `yaml:"dimension"`
|
||||
}
|
||||
|
||||
// Returns the embedding model kind
|
||||
func (cfg Config) EmbeddingModelConfigKind() string {
|
||||
return EmbeddingModelKind
|
||||
// Returns the embedding model type
|
||||
func (cfg Config) EmbeddingModelConfigType() string {
|
||||
return EmbeddingModelType
|
||||
}
|
||||
|
||||
// Initialize a Gemini embedding model
|
||||
@@ -69,9 +69,9 @@ type EmbeddingModel struct {
|
||||
Config
|
||||
}
|
||||
|
||||
// Returns the embedding model kind
|
||||
func (m EmbeddingModel) EmbeddingModelKind() string {
|
||||
return EmbeddingModelKind
|
||||
// Returns the embedding model type
|
||||
func (m EmbeddingModel) EmbeddingModelType() string {
|
||||
return EmbeddingModelType
|
||||
}
|
||||
|
||||
func (m EmbeddingModel) ToConfig() embeddingmodels.EmbeddingModelConfig {
|
||||
|
||||
@@ -42,7 +42,7 @@ func TestParseFromYamlGemini(t *testing.T) {
|
||||
want: map[string]embeddingmodels.EmbeddingModelConfig{
|
||||
"my-gemini-model": gemini.Config{
|
||||
Name: "my-gemini-model",
|
||||
Kind: gemini.EmbeddingModelKind,
|
||||
Type: gemini.EmbeddingModelType,
|
||||
Model: "text-embedding-004",
|
||||
},
|
||||
},
|
||||
@@ -60,7 +60,7 @@ func TestParseFromYamlGemini(t *testing.T) {
|
||||
want: map[string]embeddingmodels.EmbeddingModelConfig{
|
||||
"complex-gemini": gemini.Config{
|
||||
Name: "complex-gemini",
|
||||
Kind: gemini.EmbeddingModelKind,
|
||||
Type: gemini.EmbeddingModelType,
|
||||
Model: "text-embedding-004",
|
||||
ApiKey: "test-api-key",
|
||||
Dimension: 768,
|
||||
|
||||
@@ -19,6 +19,7 @@ sources:
|
||||
location: ${BIGQUERY_LOCATION:}
|
||||
useClientOAuth: ${BIGQUERY_USE_CLIENT_OAUTH:false}
|
||||
scopes: ${BIGQUERY_SCOPES:}
|
||||
maxQueryResultRows: ${BIGQUERY_MAX_QUERY_RESULT_ROWS:50}
|
||||
|
||||
tools:
|
||||
analyze_contribution:
|
||||
|
||||
@@ -43,6 +43,9 @@ tools:
|
||||
clone_instance:
|
||||
kind: cloud-sql-clone-instance
|
||||
source: cloud-sql-admin-source
|
||||
create_backup:
|
||||
kind: cloud-sql-create-backup
|
||||
source: cloud-sql-admin-source
|
||||
|
||||
toolsets:
|
||||
cloud_sql_mssql_admin_tools:
|
||||
@@ -54,3 +57,4 @@ toolsets:
|
||||
- create_user
|
||||
- wait_for_operation
|
||||
- clone_instance
|
||||
- create_backup
|
||||
|
||||
@@ -43,6 +43,9 @@ tools:
|
||||
clone_instance:
|
||||
kind: cloud-sql-clone-instance
|
||||
source: cloud-sql-admin-source
|
||||
create_backup:
|
||||
kind: cloud-sql-create-backup
|
||||
source: cloud-sql-admin-source
|
||||
|
||||
toolsets:
|
||||
cloud_sql_mysql_admin_tools:
|
||||
@@ -54,3 +57,4 @@ toolsets:
|
||||
- create_user
|
||||
- wait_for_operation
|
||||
- clone_instance
|
||||
- create_backup
|
||||
|
||||
@@ -46,6 +46,9 @@ tools:
|
||||
postgres_upgrade_precheck:
|
||||
kind: postgres-upgrade-precheck
|
||||
source: cloud-sql-admin-source
|
||||
create_backup:
|
||||
kind: cloud-sql-create-backup
|
||||
source: cloud-sql-admin-source
|
||||
|
||||
toolsets:
|
||||
cloud_sql_postgres_admin_tools:
|
||||
@@ -58,3 +61,4 @@ toolsets:
|
||||
- wait_for_operation
|
||||
- postgres_upgrade_precheck
|
||||
- clone_instance
|
||||
- create_backup
|
||||
|
||||
@@ -27,10 +27,10 @@ type Message = prompts.Message
|
||||
|
||||
const kind = "custom"
|
||||
|
||||
// init registers this prompt kind with the prompt framework.
|
||||
// init registers this prompt type with the prompt framework.
|
||||
func init() {
|
||||
if !prompts.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("prompt kind %q already registered", kind))
|
||||
panic(fmt.Sprintf("prompt type %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ type Config struct {
|
||||
var _ prompts.PromptConfig = Config{}
|
||||
var _ prompts.Prompt = Prompt{}
|
||||
|
||||
func (c Config) PromptConfigKind() string {
|
||||
func (c Config) PromptConfigType() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
|
||||
@@ -50,8 +50,8 @@ func TestConfig(t *testing.T) {
|
||||
if p == nil {
|
||||
t.Fatal("Initialize() returned a nil prompt")
|
||||
}
|
||||
if cfg.PromptConfigKind() != "custom" {
|
||||
t.Errorf("PromptConfigKind() = %q, want %q", cfg.PromptConfigKind(), "custom")
|
||||
if cfg.PromptConfigType() != "custom" {
|
||||
t.Errorf("PromptConfigType() = %q, want %q", cfg.PromptConfigType(), "custom")
|
||||
}
|
||||
|
||||
t.Run("Manifest", func(t *testing.T) {
|
||||
|
||||
@@ -52,7 +52,7 @@ func DecodeConfig(ctx context.Context, kind, name string, decoder *yaml.Decoder)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, fmt.Errorf("unknown prompt kind: %q", kind)
|
||||
return nil, fmt.Errorf("unknown prompt type: %q", kind)
|
||||
}
|
||||
|
||||
promptConfig, err := factory(ctx, name, decoder)
|
||||
@@ -63,7 +63,7 @@ func DecodeConfig(ctx context.Context, kind, name string, decoder *yaml.Decoder)
|
||||
}
|
||||
|
||||
type PromptConfig interface {
|
||||
PromptConfigKind() string
|
||||
PromptConfigType() string
|
||||
Initialize() (Prompt, error)
|
||||
}
|
||||
|
||||
|
||||
@@ -29,16 +29,16 @@ import (
|
||||
|
||||
type mockPromptConfig struct {
|
||||
name string
|
||||
kind string
|
||||
Type string
|
||||
}
|
||||
|
||||
func (m *mockPromptConfig) PromptConfigKind() string { return m.kind }
|
||||
func (m *mockPromptConfig) PromptConfigType() string { return m.Type }
|
||||
func (m *mockPromptConfig) Initialize() (prompts.Prompt, error) { return nil, nil }
|
||||
|
||||
var errMockFactory = errors.New("mock factory error")
|
||||
|
||||
func mockFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
|
||||
return &mockPromptConfig{name: name, kind: "mockKind"}, nil
|
||||
return &mockPromptConfig{name: name, Type: "mockType"}, nil
|
||||
}
|
||||
|
||||
func mockErrorFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
|
||||
@@ -50,7 +50,7 @@ func TestRegistry(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("RegisterAndDecodeSuccess", func(t *testing.T) {
|
||||
kind := "testKindSuccess"
|
||||
kind := "testTypeSuccess"
|
||||
if !prompts.Register(kind, mockFactory) {
|
||||
t.Fatal("expected registration to succeed")
|
||||
}
|
||||
@@ -69,19 +69,19 @@ func TestRegistry(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DecodeUnknownKind", func(t *testing.T) {
|
||||
t.Run("DecodeUnknownType", func(t *testing.T) {
|
||||
decoder := yaml.NewDecoder(strings.NewReader(""))
|
||||
_, err := prompts.DecodeConfig(ctx, "unregisteredKind", "testPrompt", decoder)
|
||||
_, err := prompts.DecodeConfig(ctx, "unregisteredType", "testPrompt", decoder)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error for unknown kind, but got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown prompt kind") {
|
||||
t.Errorf("expected error to contain 'unknown prompt kind', but got: %v", err)
|
||||
if !strings.Contains(err.Error(), "unknown prompt type") {
|
||||
t.Errorf("expected error to contain 'unknown prompt type', but got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FactoryReturnsError", func(t *testing.T) {
|
||||
kind := "testKindError"
|
||||
kind := "testTypeError"
|
||||
if !prompts.Register(kind, mockErrorFactory) {
|
||||
t.Fatal("expected registration to succeed")
|
||||
}
|
||||
@@ -105,8 +105,8 @@ func TestRegistry(t *testing.T) {
|
||||
if config == nil {
|
||||
t.Fatal("expected a non-nil config for default kind")
|
||||
}
|
||||
if config.PromptConfigKind() != "custom" {
|
||||
t.Errorf("expected default kind to be 'custom', but got %q", config.PromptConfigKind())
|
||||
if config.PromptConfigType() != "custom" {
|
||||
t.Errorf("expected default kind to be 'custom', but got %q", config.PromptConfigType())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -198,7 +198,7 @@ func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(i
|
||||
return fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
switch kind {
|
||||
case google.AuthServiceKind:
|
||||
case google.AuthServiceType:
|
||||
actual := google.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
@@ -242,7 +242,7 @@ func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal fun
|
||||
return fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
switch kind {
|
||||
case gemini.EmbeddingModelKind:
|
||||
case gemini.EmbeddingModelType:
|
||||
actual := gemini.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
"example-source": &alloydbpg.Source{
|
||||
Config: alloydbpg.Config{
|
||||
Name: "example-alloydb-source",
|
||||
Kind: "alloydb-postgres",
|
||||
Type: "alloydb-postgres",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
"example-source2": &alloydbpg.Source{
|
||||
Config: alloydbpg.Config{
|
||||
Name: "example-alloydb-source2",
|
||||
Kind: "alloydb-postgres",
|
||||
Type: "alloydb-postgres",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
childCtx, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/source/init",
|
||||
trace.WithAttributes(attribute.String("source_kind", sc.SourceConfigKind())),
|
||||
trace.WithAttributes(attribute.String("source_type", sc.SourceConfigType())),
|
||||
trace.WithAttributes(attribute.String("source_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
@@ -110,7 +110,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/auth/init",
|
||||
trace.WithAttributes(attribute.String("auth_kind", sc.AuthServiceConfigKind())),
|
||||
trace.WithAttributes(attribute.String("auth_type", sc.AuthServiceConfigType())),
|
||||
trace.WithAttributes(attribute.String("auth_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
@@ -138,7 +138,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/embeddingmodel/init",
|
||||
trace.WithAttributes(attribute.String("model_kind", ec.EmbeddingModelConfigKind())),
|
||||
trace.WithAttributes(attribute.String("model_type", ec.EmbeddingModelConfigType())),
|
||||
trace.WithAttributes(attribute.String("model_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
@@ -166,7 +166,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/tool/init",
|
||||
trace.WithAttributes(attribute.String("tool_kind", tc.ToolConfigKind())),
|
||||
trace.WithAttributes(attribute.String("tool_type", tc.ToolConfigType())),
|
||||
trace.WithAttributes(attribute.String("tool_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
@@ -235,7 +235,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/prompt/init",
|
||||
trace.WithAttributes(attribute.String("prompt_kind", pc.PromptConfigKind())),
|
||||
trace.WithAttributes(attribute.String("prompt_type", pc.PromptConfigType())),
|
||||
trace.WithAttributes(attribute.String("prompt_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
|
||||
@@ -141,7 +141,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
"example-source": &alloydbpg.Source{
|
||||
Config: alloydbpg.Config{
|
||||
Name: "example-alloydb-source",
|
||||
Kind: "alloydb-postgres",
|
||||
Type: "alloydb-postgres",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -32,14 +32,14 @@ import (
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const SourceKind string = "alloydb-admin"
|
||||
const SourceType string = "alloydb-admin"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,13 +53,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
DefaultProject string `yaml:"defaultProject"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -106,8 +106,8 @@ type Source struct {
|
||||
Service *alloydbrestapi.Service
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-alloydb-admin-instance": alloydbadmin.Config{
|
||||
Name: "my-alloydb-admin-instance",
|
||||
Kind: alloydbadmin.SourceKind,
|
||||
Type: alloydbadmin.SourceType,
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
@@ -57,7 +57,7 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-alloydb-admin-instance": alloydbadmin.Config{
|
||||
Name: "my-alloydb-admin-instance",
|
||||
Kind: alloydbadmin.SourceKind,
|
||||
Type: alloydbadmin.SourceType,
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -29,14 +29,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "alloydb-postgres"
|
||||
const SourceType string = "alloydb-postgres"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Region string `yaml:"region" validate:"required"`
|
||||
Cluster string `yaml:"cluster" validate:"required"`
|
||||
@@ -61,8 +61,8 @@ type Config struct {
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -90,8 +90,8 @@ type Source struct {
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -183,7 +183,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string
|
||||
|
||||
func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, cluster, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
dsn, useIAM, err := getConnectionConfig(ctx, user, pass, dbname)
|
||||
|
||||
@@ -48,7 +48,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-pg-instance": alloydbpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: alloydbpg.SourceKind,
|
||||
Type: alloydbpg.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Cluster: "my-cluster",
|
||||
@@ -78,7 +78,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-pg-instance": alloydbpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: alloydbpg.SourceKind,
|
||||
Type: alloydbpg.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Cluster: "my-cluster",
|
||||
@@ -108,7 +108,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-pg-instance": alloydbpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: alloydbpg.SourceKind,
|
||||
Type: alloydbpg.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Cluster: "my-cluster",
|
||||
|
||||
@@ -41,7 +41,7 @@ import (
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const SourceKind string = "bigquery"
|
||||
const SourceType string = "bigquery"
|
||||
|
||||
// CloudPlatformScope is a broad scope for Google Cloud Platform services.
|
||||
const CloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
|
||||
@@ -65,8 +65,8 @@ type BigQuerySessionProvider func(ctx context.Context) (*Session, error)
|
||||
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
// BigQuery configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Location string `yaml:"location"`
|
||||
WriteMode string `yaml:"writeMode"`
|
||||
@@ -89,6 +89,7 @@ type Config struct {
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
ImpersonateServiceAccount string `yaml:"impersonateServiceAccount"`
|
||||
Scopes StringOrStringSlice `yaml:"scopes"`
|
||||
MaxQueryResultRows int `yaml:"maxQueryResultRows"`
|
||||
}
|
||||
|
||||
// StringOrStringSlice is a custom type that can unmarshal both a single string
|
||||
@@ -118,15 +119,19 @@ func (s *StringOrStringSlice) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
return fmt.Errorf("cannot unmarshal %T into StringOrStringSlice", v)
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
func (r Config) SourceConfigType() string {
|
||||
// Returns BigQuery source kind
|
||||
return SourceKind
|
||||
return SourceType
|
||||
}
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
if r.WriteMode == "" {
|
||||
r.WriteMode = WriteModeAllowed
|
||||
}
|
||||
|
||||
if r.MaxQueryResultRows == 0 {
|
||||
r.MaxQueryResultRows = 50
|
||||
}
|
||||
|
||||
if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
|
||||
// The protected mode only allows write operations to the session's temporary datasets.
|
||||
// when using client OAuth, a new session is created every
|
||||
@@ -150,7 +155,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
Client: client,
|
||||
RestService: restService,
|
||||
TokenSource: tokenSource,
|
||||
MaxQueryResultRows: 50,
|
||||
MaxQueryResultRows: r.MaxQueryResultRows,
|
||||
ClientCreator: clientCreator,
|
||||
}
|
||||
|
||||
@@ -297,9 +302,9 @@ type Session struct {
|
||||
LastUsed time.Time
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
func (s *Source) SourceType() string {
|
||||
// Returns BigQuery Google SQL source kind
|
||||
return SourceKind
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -567,7 +572,7 @@ func (s *Source) RunSQL(ctx context.Context, bqClient *bigqueryapi.Client, state
|
||||
}
|
||||
|
||||
var out []any
|
||||
for {
|
||||
for s.MaxQueryResultRows <= 0 || len(out) < s.MaxQueryResultRows {
|
||||
var val []bigqueryapi.Value
|
||||
err = it.Next(&val)
|
||||
if err == iterator.Done {
|
||||
@@ -660,7 +665,7 @@ func initBigQueryConnection(
|
||||
impersonateServiceAccount string,
|
||||
scopes []string,
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
@@ -736,7 +741,7 @@ func initBigQueryConnectionWithOAuthToken(
|
||||
tokenString string,
|
||||
wantRestService bool,
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
// Construct token source
|
||||
token := &oauth2.Token{
|
||||
@@ -796,7 +801,7 @@ func initDataplexConnection(
|
||||
var clientCreator DataplexClientCreator
|
||||
var err error
|
||||
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
|
||||
@@ -21,9 +21,12 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
|
||||
func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
@@ -43,7 +46,7 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Type: bigquery.SourceType,
|
||||
Project: "my-project",
|
||||
Location: "",
|
||||
WriteMode: "",
|
||||
@@ -63,7 +66,7 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Type: bigquery.SourceType,
|
||||
Project: "my-project",
|
||||
Location: "asia",
|
||||
WriteMode: "blocked",
|
||||
@@ -84,7 +87,7 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Type: bigquery.SourceType,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
UseClientOAuth: true,
|
||||
@@ -105,7 +108,7 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Type: bigquery.SourceType,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
AllowedDatasets: []string{"my_dataset"},
|
||||
@@ -125,7 +128,7 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Type: bigquery.SourceType,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
ImpersonateServiceAccount: "service-account@my-project.iam.gserviceaccount.com",
|
||||
@@ -147,13 +150,33 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Type: bigquery.SourceType,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
Scopes: []string{"https://www.googleapis.com/auth/bigquery", "https://www.googleapis.com/auth/cloud-platform"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with max query result rows example",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
maxQueryResultRows: 10
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Type: bigquery.SourceType,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
MaxQueryResultRows: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
@@ -220,6 +243,59 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitialize_MaxQueryResultRows(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
ctx = util.WithUserAgent(ctx, "test-agent")
|
||||
tracer := noop.NewTracerProvider().Tracer("")
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
cfg bigquery.Config
|
||||
want int
|
||||
}{
|
||||
{
|
||||
desc: "default value",
|
||||
cfg: bigquery.Config{
|
||||
Name: "test-default",
|
||||
Type: bigquery.SourceType,
|
||||
Project: "test-project",
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
want: 50,
|
||||
},
|
||||
{
|
||||
desc: "configured value",
|
||||
cfg: bigquery.Config{
|
||||
Name: "test-configured",
|
||||
Type: bigquery.SourceType,
|
||||
Project: "test-project",
|
||||
UseClientOAuth: true,
|
||||
MaxQueryResultRows: 100,
|
||||
},
|
||||
want: 100,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
src, err := tc.cfg.Initialize(ctx, tracer)
|
||||
if err != nil {
|
||||
t.Fatalf("Initialize failed: %v", err)
|
||||
}
|
||||
bqSrc, ok := src.(*bigquery.Source)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *bigquery.Source, got %T", src)
|
||||
}
|
||||
if bqSrc.MaxQueryResultRows != tc.want {
|
||||
t.Errorf("MaxQueryResultRows = %d, want %d", bqSrc.MaxQueryResultRows, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -27,14 +27,14 @@ import (
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const SourceKind string = "bigtable"
|
||||
const SourceType string = "bigtable"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,13 +48,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Instance string `yaml:"instance" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -77,8 +77,8 @@ type Source struct {
|
||||
Client *bigtable.Client
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -179,7 +179,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, configParam param
|
||||
|
||||
func initBigtableClient(ctx context.Context, tracer trace.Tracer, name, project, instance string) (*bigtable.Client, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
// Set up Bigtable data operations client.
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestParseFromYamlBigtableDb(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-bigtable-instance": bigtable.Config{
|
||||
Name: "my-bigtable-instance",
|
||||
Kind: bigtable.SourceKind,
|
||||
Type: bigtable.SourceType,
|
||||
Project: "my-project",
|
||||
Instance: "my-instance",
|
||||
},
|
||||
|
||||
@@ -25,11 +25,11 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "cassandra"
|
||||
const SourceType string = "cassandra"
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Hosts []string `yaml:"hosts" validate:"required"`
|
||||
Keyspace string `yaml:"keyspace"`
|
||||
ProtoVersion int `yaml:"protoVersion"`
|
||||
@@ -68,9 +68,9 @@ func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// SourceConfigKind implements sources.SourceConfig.
|
||||
func (c Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
// SourceConfigType implements sources.SourceConfig.
|
||||
func (c Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
var _ sources.SourceConfig = Config{}
|
||||
@@ -89,9 +89,9 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
// SourceKind implements sources.Source.
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
// SourceType implements sources.Source.
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) RunSQL(ctx context.Context, statement string, params parameters.ParamValues) (any, error) {
|
||||
@@ -120,7 +120,7 @@ var _ sources.Source = &Source{}
|
||||
|
||||
func initCassandraSession(ctx context.Context, tracer trace.Tracer, c Config) (*gocql.Session, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, c.Name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, c.Name)
|
||||
defer span.End()
|
||||
|
||||
// Validate authentication configuration
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestParseFromYamlCassandra(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-cassandra-instance": cassandra.Config{
|
||||
Name: "my-cassandra-instance",
|
||||
Kind: cassandra.SourceKind,
|
||||
Type: cassandra.SourceType,
|
||||
Hosts: []string{"my-host1", "my-host2"},
|
||||
Username: "",
|
||||
Password: "",
|
||||
@@ -77,7 +77,7 @@ func TestParseFromYamlCassandra(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-cassandra-instance": cassandra.Config{
|
||||
Name: "my-cassandra-instance",
|
||||
Kind: cassandra.SourceKind,
|
||||
Type: cassandra.SourceType,
|
||||
Hosts: []string{"my-host1", "my-host2"},
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
|
||||
@@ -28,14 +28,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "clickhouse"
|
||||
const SourceType string = "clickhouse"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
@@ -59,8 +59,8 @@ type Config struct {
|
||||
Secure bool `yaml:"secure"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -88,8 +88,8 @@ type Source struct {
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -174,7 +174,7 @@ func validateConfig(protocol string) error {
|
||||
|
||||
func initClickHouseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, protocol string, secure bool) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
if protocol == "" {
|
||||
|
||||
@@ -25,10 +25,10 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
)
|
||||
|
||||
func TestConfigSourceConfigKind(t *testing.T) {
|
||||
func TestConfigSourceConfigType(t *testing.T) {
|
||||
config := Config{}
|
||||
if config.SourceConfigKind() != SourceKind {
|
||||
t.Errorf("Expected %s, got %s", SourceKind, config.SourceConfigKind())
|
||||
if config.SourceConfigType() != SourceType {
|
||||
t.Errorf("Expected %s, got %s", SourceType, config.SourceConfigType())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func TestNewConfig(t *testing.T) {
|
||||
`,
|
||||
expected: Config{
|
||||
Name: "test-clickhouse",
|
||||
Kind: "clickhouse",
|
||||
Type: "clickhouse",
|
||||
Host: "localhost",
|
||||
Port: "8443",
|
||||
User: "default",
|
||||
@@ -75,7 +75,7 @@ func TestNewConfig(t *testing.T) {
|
||||
`,
|
||||
expected: Config{
|
||||
Name: "minimal-clickhouse",
|
||||
Kind: "clickhouse",
|
||||
Type: "clickhouse",
|
||||
Host: "127.0.0.1",
|
||||
Port: "8123",
|
||||
User: "testuser",
|
||||
@@ -100,7 +100,7 @@ func TestNewConfig(t *testing.T) {
|
||||
`,
|
||||
expected: Config{
|
||||
Name: "http-clickhouse",
|
||||
Kind: "clickhouse",
|
||||
Type: "clickhouse",
|
||||
Host: "clickhouse.example.com",
|
||||
Port: "8123",
|
||||
User: "analytics",
|
||||
@@ -125,7 +125,7 @@ func TestNewConfig(t *testing.T) {
|
||||
`,
|
||||
expected: Config{
|
||||
Name: "secure-clickhouse",
|
||||
Kind: "clickhouse",
|
||||
Type: "clickhouse",
|
||||
Host: "secure.clickhouse.io",
|
||||
Port: "8443",
|
||||
User: "secureuser",
|
||||
@@ -196,10 +196,10 @@ func TestNewConfigInvalidYAML(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSource_SourceKind(t *testing.T) {
|
||||
func TestSource_SourceType(t *testing.T) {
|
||||
source := &Source{}
|
||||
if source.SourceKind() != SourceKind {
|
||||
t.Errorf("Expected %s, got %s", SourceKind, source.SourceKind())
|
||||
if source.SourceType() != SourceType {
|
||||
t.Errorf("Expected %s, got %s", SourceType, source.SourceType())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -29,15 +29,15 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-gemini-data-analytics"
|
||||
const SourceType string = "cloud-gemini-data-analytics"
|
||||
const Endpoint string = "https://geminidataanalytics.googleapis.com"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,13 +51,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
ProjectID string `yaml:"projectId" validate:"required"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
// Initialize initializes a Gemini Data Analytics Source instance.
|
||||
@@ -102,8 +102,8 @@ type Source struct {
|
||||
userAgent string
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -47,7 +47,7 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-gda-instance": cloudgda.Config{
|
||||
Name: "my-gda-instance",
|
||||
Kind: cloudgda.SourceKind,
|
||||
Type: cloudgda.SourceType,
|
||||
ProjectID: "test-project-id",
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
@@ -65,7 +65,7 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-gda-instance": cloudgda.Config{
|
||||
Name: "my-gda-instance",
|
||||
Kind: cloudgda.SourceKind,
|
||||
Type: cloudgda.SourceType,
|
||||
ProjectID: "another-project",
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
@@ -153,12 +153,12 @@ func TestInitialize(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
desc: "initialize with ADC",
|
||||
cfg: cloudgda.Config{Name: "test-gda", Kind: cloudgda.SourceKind, ProjectID: "test-proj"},
|
||||
cfg: cloudgda.Config{Name: "test-gda", Type: cloudgda.SourceType, ProjectID: "test-proj"},
|
||||
wantClientOAuth: false,
|
||||
},
|
||||
{
|
||||
desc: "initialize with client OAuth",
|
||||
cfg: cloudgda.Config{Name: "test-gda-oauth", Kind: cloudgda.SourceKind, ProjectID: "test-proj", UseClientOAuth: true},
|
||||
cfg: cloudgda.Config{Name: "test-gda-oauth", Type: cloudgda.SourceType, ProjectID: "test-proj", UseClientOAuth: true},
|
||||
wantClientOAuth: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -16,8 +16,12 @@ package cloudhealthcare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -30,7 +34,7 @@ import (
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-healthcare"
|
||||
const SourceType string = "cloud-healthcare"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
@@ -38,8 +42,8 @@ var _ sources.SourceConfig = Config{}
|
||||
type HealthcareServiceCreator func(tokenString string) (*healthcare.Service, error)
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,7 +58,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
// Healthcare configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Region string `yaml:"region" validate:"required"`
|
||||
Dataset string `yaml:"dataset" validate:"required"`
|
||||
@@ -63,8 +67,8 @@ type Config struct {
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (c Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (c Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -140,7 +144,7 @@ func newHealthcareServiceCreator(ctx context.Context, tracer trace.Tracer, name
|
||||
}
|
||||
|
||||
func initHealthcareConnectionWithOAuthToken(ctx context.Context, tracer trace.Tracer, name string, userAgent string, tokenString string) (*healthcare.Service, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
// Construct token source
|
||||
token := &oauth2.Token{
|
||||
@@ -158,7 +162,7 @@ func initHealthcareConnectionWithOAuthToken(ctx context.Context, tracer trace.Tr
|
||||
}
|
||||
|
||||
func initHealthcareConnection(ctx context.Context, tracer trace.Tracer, name string) (*healthcare.Service, oauth2.TokenSource, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
cred, err := google.FindDefaultCredentials(ctx, healthcare.CloudHealthcareScope)
|
||||
@@ -190,8 +194,8 @@ type Source struct {
|
||||
allowedDICOMStores map[string]struct{}
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -255,3 +259,299 @@ func (s *Source) IsDICOMStoreAllowed(storeID string) bool {
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return s.UseClientOAuth
|
||||
}
|
||||
|
||||
func parseResults(resp *http.Response) (any, error) {
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
var jsonMap map[string]interface{}
|
||||
if err := json.Unmarshal(respBytes, &jsonMap); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
|
||||
}
|
||||
return jsonMap, nil
|
||||
}
|
||||
|
||||
func (s *Source) getService(tokenStr string) (*healthcare.Service, error) {
|
||||
svc := s.Service()
|
||||
var err error
|
||||
// Initialize new service if using user OAuth token
|
||||
if s.UseClientAuthorization() {
|
||||
svc, err = s.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (s *Source) FHIRFetchPage(ctx context.Context, url, tokenStr string) (any, error) {
|
||||
var httpClient *http.Client
|
||||
if s.UseClientAuthorization() {
|
||||
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
|
||||
httpClient = oauth2.NewClient(ctx, ts)
|
||||
} else {
|
||||
// The source.Service() object holds a client with the default credentials.
|
||||
// However, the client is not exported, so we have to create a new one.
|
||||
var err error
|
||||
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create default http client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create http request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parseResults(resp)
|
||||
}
|
||||
|
||||
func (s *Source) FHIRPatientEverything(storeID, patientID, tokenStr string, opts []googleapi.CallOption) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", s.Project(), s.Region(), s.DatasetID(), storeID, patientID)
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.PatientEverything(name).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call patient everything for %q: %w", name, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parseResults(resp)
|
||||
}
|
||||
|
||||
func (s *Source) FHIRPatientSearch(storeID, tokenStr string, opts []googleapi.CallOption) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search patient resources: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parseResults(resp)
|
||||
}
|
||||
|
||||
func (s *Source) GetDataset(tokenStr string) (*healthcare.Dataset, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
|
||||
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
return dataset, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetFHIRResource(storeID, resType, resID, tokenStr string) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", s.Project(), s.Region(), s.DatasetID(), storeID, resType, resID)
|
||||
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
|
||||
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
|
||||
resp, err := call.Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get fhir resource %q: %w", name, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parseResults(resp)
|
||||
}
|
||||
|
||||
func (s *Source) GetDICOMStore(storeID, tokenStr string) (*healthcare.DicomStore, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetFHIRStore(storeID, tokenStr string) (*healthcare.FhirStore, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetDICOMStoreMetrics(storeID, tokenStr string) (*healthcare.DicomStoreMetrics, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetFHIRStoreMetrics(storeID, tokenStr string) (*healthcare.FhirStoreMetrics, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *Source) ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.DicomStore
|
||||
for _, store := range stores.DicomStores {
|
||||
if len(s.AllowedDICOMStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
if len(store.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := s.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (s *Source) ListFHIRStores(tokenStr string) ([]*healthcare.FhirStore, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.FhirStore
|
||||
for _, store := range stores.FhirStores {
|
||||
if len(s.AllowedFHIRStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
if len(store.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := s.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (s *Source) RetrieveRenderedDICOMInstance(storeID, study, series, sop string, frame int, tokenStr string) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
|
||||
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
|
||||
call.Header().Set("Accept", "image/jpeg")
|
||||
resp, err := call.Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve dicom instance rendered image: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("RetrieveRendered: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
base64String := base64.StdEncoding.EncodeToString(respBytes)
|
||||
return base64String, nil
|
||||
}
|
||||
|
||||
func (s *Source) SearchDICOM(toolType, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
var resp *http.Response
|
||||
switch toolType {
|
||||
case "cloud-healthcare-search-dicom-instances":
|
||||
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
|
||||
case "cloud-healthcare-search-dicom-series":
|
||||
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
|
||||
case "cloud-healthcare-search-dicom-studies":
|
||||
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, dicomWebPath).Do(opts...)
|
||||
default:
|
||||
return nil, fmt.Errorf("incompatible tool type: %s", toolType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom series: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
if len(respBytes) == 0 {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
var result []interface{}
|
||||
if err := json.Unmarshal(respBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudhealthcare.Config{
|
||||
Name: "my-instance",
|
||||
Kind: cloudhealthcare.SourceKind,
|
||||
Type: cloudhealthcare.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "us-central1",
|
||||
Dataset: "my-dataset",
|
||||
@@ -65,7 +65,7 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudhealthcare.Config{
|
||||
Name: "my-instance",
|
||||
Kind: cloudhealthcare.SourceKind,
|
||||
Type: cloudhealthcare.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "us",
|
||||
Dataset: "my-dataset",
|
||||
@@ -91,7 +91,7 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudhealthcare.Config{
|
||||
Name: "my-instance",
|
||||
Kind: cloudhealthcare.SourceKind,
|
||||
Type: cloudhealthcare.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "us",
|
||||
Dataset: "my-dataset",
|
||||
|
||||
@@ -29,14 +29,14 @@ import (
|
||||
monitoring "google.golang.org/api/monitoring/v3"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-monitoring"
|
||||
const SourceType string = "cloud-monitoring"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,12 +50,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
// Initialize initializes a Cloud Monitoring Source instance.
|
||||
@@ -99,8 +99,8 @@ type Source struct {
|
||||
userAgent string
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -42,7 +42,7 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-cloud-monitoring-instance": cloudmonitoring.Config{
|
||||
Name: "my-cloud-monitoring-instance",
|
||||
Kind: cloudmonitoring.SourceKind,
|
||||
Type: cloudmonitoring.SourceType,
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
@@ -58,7 +58,7 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-cloud-monitoring-instance": cloudmonitoring.Config{
|
||||
Name: "my-cloud-monitoring-instance",
|
||||
Kind: cloudmonitoring.SourceKind,
|
||||
Type: cloudmonitoring.SourceType,
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -34,7 +34,7 @@ import (
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-sql-admin"
|
||||
const SourceType string = "cloud-sql-admin"
|
||||
|
||||
var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
|
||||
|
||||
@@ -42,8 +42,8 @@ var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/da
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,13 +57,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
DefaultProject string `yaml:"defaultProject"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
// Initialize initializes a CloudSQL Admin Source instance.
|
||||
@@ -110,8 +110,8 @@ type Source struct {
|
||||
Service *sqladmin.Service
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -352,6 +352,28 @@ func (s *Source) GetWaitForOperations(ctx context.Context, service *sqladmin.Ser
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *Source) InsertBackupRun(ctx context.Context, project, instance, location, backupDescription, accessToken string) (any, error) {
|
||||
backupRun := &sqladmin.BackupRun{}
|
||||
if location != "" {
|
||||
backupRun.Location = location
|
||||
}
|
||||
if backupDescription != "" {
|
||||
backupRun.Description = backupDescription
|
||||
}
|
||||
|
||||
service, err := s.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := service.BackupRuns.Insert(project, instance, backupRun).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating backup: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func generateCloudSQLConnectionMessage(ctx context.Context, source *Source, logger log.Logger, opResponse map[string]any, connectionMessageTemplate string) (string, bool) {
|
||||
operationType, ok := opResponse["operationType"].(string)
|
||||
if !ok || operationType != "CREATE_DATABASE" {
|
||||
|
||||
@@ -42,7 +42,7 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-cloud-sql-admin-instance": cloudsqladmin.Config{
|
||||
Name: "my-cloud-sql-admin-instance",
|
||||
Kind: cloudsqladmin.SourceKind,
|
||||
Type: cloudsqladmin.SourceType,
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
@@ -58,7 +58,7 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-cloud-sql-admin-instance": cloudsqladmin.Config{
|
||||
Name: "my-cloud-sql-admin-instance",
|
||||
Kind: cloudsqladmin.SourceKind,
|
||||
Type: cloudsqladmin.SourceType,
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -29,14 +29,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-sql-mssql"
|
||||
const SourceType string = "cloud-sql-mssql"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
// Cloud SQL MSSQL configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Region string `yaml:"region" validate:"required"`
|
||||
Instance string `yaml:"instance" validate:"required"`
|
||||
@@ -62,9 +62,9 @@ type Config struct {
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
func (r Config) SourceConfigType() string {
|
||||
// Returns Cloud SQL MSSQL source kind
|
||||
return SourceKind
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -94,9 +94,9 @@ type Source struct {
|
||||
Db *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
func (s *Source) SourceType() string {
|
||||
// Returns Cloud SQL MSSQL source kind
|
||||
return SourceKind
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -152,7 +152,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
|
||||
@@ -46,7 +46,7 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudsqlmssql.Config{
|
||||
Name: "my-instance",
|
||||
Kind: cloudsqlmssql.SourceKind,
|
||||
Type: cloudsqlmssql.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -74,7 +74,7 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudsqlmssql.Config{
|
||||
Name: "my-instance",
|
||||
Kind: cloudsqlmssql.SourceKind,
|
||||
Type: cloudsqlmssql.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -102,7 +102,7 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudsqlmssql.Config{
|
||||
Name: "my-instance",
|
||||
Kind: cloudsqlmssql.SourceKind,
|
||||
Type: cloudsqlmssql.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
|
||||
@@ -30,14 +30,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-sql-mysql"
|
||||
const SourceType string = "cloud-sql-mysql"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Region string `yaml:"region" validate:"required"`
|
||||
Instance string `yaml:"instance" validate:"required"`
|
||||
@@ -61,8 +61,8 @@ type Config struct {
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -90,8 +90,8 @@ type Source struct {
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -184,7 +184,7 @@ func getConnectionConfig(ctx context.Context, user, pass string) (string, string
|
||||
|
||||
func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
|
||||
@@ -46,7 +46,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": cloudsqlmysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Kind: cloudsqlmysql.SourceKind,
|
||||
Type: cloudsqlmysql.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -74,7 +74,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": cloudsqlmysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Kind: cloudsqlmysql.SourceKind,
|
||||
Type: cloudsqlmysql.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -102,7 +102,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": cloudsqlmysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Kind: cloudsqlmysql.SourceKind,
|
||||
Type: cloudsqlmysql.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -130,7 +130,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": cloudsqlmysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Kind: cloudsqlmysql.SourceKind,
|
||||
Type: cloudsqlmysql.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
|
||||
@@ -28,14 +28,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-sql-postgres"
|
||||
const SourceType string = "cloud-sql-postgres"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Region string `yaml:"region" validate:"required"`
|
||||
Instance string `yaml:"instance" validate:"required"`
|
||||
@@ -59,8 +59,8 @@ type Config struct {
|
||||
Password string `yaml:"password"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -88,8 +88,8 @@ type Source struct {
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -162,7 +162,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string
|
||||
|
||||
func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
|
||||
@@ -46,7 +46,7 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: cloudsqlpg.SourceKind,
|
||||
Type: cloudsqlpg.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -74,7 +74,7 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: cloudsqlpg.SourceKind,
|
||||
Type: cloudsqlpg.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -102,7 +102,7 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: cloudsqlpg.SourceKind,
|
||||
Type: cloudsqlpg.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -130,7 +130,7 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: cloudsqlpg.SourceKind,
|
||||
Type: cloudsqlpg.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
|
||||
@@ -29,14 +29,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "couchbase"
|
||||
const SourceType string = "couchbase"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
ConnectionString string `yaml:"connectionString" validate:"required"`
|
||||
Bucket string `yaml:"bucket" validate:"required"`
|
||||
Scope string `yaml:"scope" validate:"required"`
|
||||
@@ -66,8 +66,8 @@ type Config struct {
|
||||
QueryScanConsistency uint `yaml:"queryScanConsistency"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -96,8 +96,8 @@ type Source struct {
|
||||
Scope *gocb.Scope
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestParseFromYamlCouchbase(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-couchbase-instance": couchbase.Config{
|
||||
Name: "my-couchbase-instance",
|
||||
Kind: couchbase.SourceKind,
|
||||
Type: couchbase.SourceType,
|
||||
ConnectionString: "localhost",
|
||||
Username: "Administrator",
|
||||
Password: "password",
|
||||
@@ -74,7 +74,7 @@ func TestParseFromYamlCouchbase(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-couchbase-instance": couchbase.Config{
|
||||
Name: "my-couchbase-instance",
|
||||
Kind: couchbase.SourceKind,
|
||||
Type: couchbase.SourceType,
|
||||
ConnectionString: "couchbases://localhost",
|
||||
Bucket: "travel-sample",
|
||||
Scope: "inventory",
|
||||
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
"fmt"
|
||||
|
||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
@@ -27,14 +29,14 @@ import (
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const SourceKind string = "dataplex"
|
||||
const SourceType string = "dataplex"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,13 +51,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
// Dataplex configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
func (r Config) SourceConfigType() string {
|
||||
// Returns Dataplex source kind
|
||||
return SourceKind
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -79,9 +81,9 @@ type Source struct {
|
||||
Client *dataplexapi.CatalogClient
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
func (s *Source) SourceType() string {
|
||||
// Returns Dataplex source kind
|
||||
return SourceKind
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -102,7 +104,7 @@ func initDataplexConnection(
|
||||
name string,
|
||||
project string,
|
||||
) (*dataplexapi.CatalogClient, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
cred, err := google.FindDefaultCredentials(ctx)
|
||||
@@ -121,3 +123,101 @@ func initDataplexConnection(
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *Source) LookupEntry(ctx context.Context, name string, view int, aspectTypes []string, entry string) (*dataplexpb.Entry, error) {
|
||||
viewMap := map[int]dataplexpb.EntryView{
|
||||
1: dataplexpb.EntryView_BASIC,
|
||||
2: dataplexpb.EntryView_FULL,
|
||||
3: dataplexpb.EntryView_CUSTOM,
|
||||
4: dataplexpb.EntryView_ALL,
|
||||
}
|
||||
req := &dataplexpb.LookupEntryRequest{
|
||||
Name: name,
|
||||
View: viewMap[view],
|
||||
AspectTypes: aspectTypes,
|
||||
Entry: entry,
|
||||
}
|
||||
result, err := s.CatalogClient().LookupEntry(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Source) searchRequest(ctx context.Context, query string, pageSize int, orderBy string) (*dataplexapi.SearchEntriesResultIterator, error) {
|
||||
// Create SearchEntriesRequest with the provided parameters
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: query,
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", s.ProjectID()),
|
||||
PageSize: int32(pageSize),
|
||||
OrderBy: orderBy,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
// Perform the search using the CatalogClient - this will return an iterator
|
||||
it := s.CatalogClient().SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", s.ProjectID())
|
||||
}
|
||||
return it, nil
|
||||
}
|
||||
|
||||
func (s *Source) SearchAspectTypes(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.AspectType, error) {
|
||||
q := query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype"
|
||||
it, err := s.searchRequest(ctx, q, pageSize, orderBy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Iterate through the search results and call GetAspectType for each result using the resource name
|
||||
var results []*dataplexpb.AspectType
|
||||
for {
|
||||
entry, err := it.Next()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
|
||||
// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
|
||||
getAspectBackOff := backoff.NewExponentialBackOff()
|
||||
|
||||
resourceName := entry.DataplexEntry.GetEntrySource().Resource
|
||||
getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
|
||||
Name: resourceName,
|
||||
}
|
||||
|
||||
operation := func() (*dataplexpb.AspectType, error) {
|
||||
aspectType, err := s.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
|
||||
}
|
||||
return aspectType, nil
|
||||
}
|
||||
|
||||
// Retry the GetAspectType operation with exponential backoff
|
||||
aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
|
||||
}
|
||||
|
||||
results = append(results, aspectType)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Source) SearchEntries(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.SearchEntriesResult, error) {
|
||||
it, err := s.searchRequest(ctx, query, pageSize, orderBy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var results []*dataplexpb.SearchEntriesResult
|
||||
for {
|
||||
entry, err := it.Next()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
results = append(results, entry)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestParseFromYamlDataplex(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": dataplex.Config{
|
||||
Name: "my-instance",
|
||||
Kind: dataplex.SourceKind,
|
||||
Type: dataplex.SourceType,
|
||||
Project: "my-project",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -30,14 +30,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "dgraph"
|
||||
const SourceType string = "dgraph"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ type DgraphClient struct {
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
DgraphUrl string `yaml:"dgraphUrl" validate:"required"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
@@ -75,8 +75,8 @@ type Config struct {
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -103,8 +103,8 @@ type Source struct {
|
||||
Client *DgraphClient `yaml:"client"`
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -139,7 +139,7 @@ func (s *Source) RunSQL(statement string, params parameters.ParamValues, isQuery
|
||||
|
||||
func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (*DgraphClient, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, r.Name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, r.Name)
|
||||
defer span.End()
|
||||
|
||||
if r.DgraphUrl == "" {
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestParseFromYamlDgraph(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-dgraph-instance": dgraph.Config{
|
||||
Name: "my-dgraph-instance",
|
||||
Kind: dgraph.SourceKind,
|
||||
Type: dgraph.SourceType,
|
||||
DgraphUrl: "https://localhost:8080",
|
||||
ApiKey: "abc123",
|
||||
Password: "pass@123",
|
||||
@@ -65,7 +65,7 @@ func TestParseFromYamlDgraph(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-dgraph-instance": dgraph.Config{
|
||||
Name: "my-dgraph-instance",
|
||||
Kind: dgraph.SourceKind,
|
||||
Type: dgraph.SourceType,
|
||||
DgraphUrl: "https://localhost:8080",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -30,14 +30,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "elasticsearch"
|
||||
const SourceType string = "elasticsearch"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,15 +51,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Addresses []string `yaml:"addresses" validate:"required"`
|
||||
Username string `yaml:"username"`
|
||||
Password string `yaml:"password"`
|
||||
APIKey string `yaml:"apikey"`
|
||||
}
|
||||
|
||||
func (c Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (c Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
type EsClient interface {
|
||||
@@ -139,9 +139,9 @@ func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// SourceKind returns the kind string for this source.
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
// SourceType returns the kind string for this source.
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestParseFromYamlElasticsearch(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-es-instance": elasticsearch.Config{
|
||||
Name: "my-es-instance",
|
||||
Kind: elasticsearch.SourceKind,
|
||||
Type: elasticsearch.SourceType,
|
||||
Addresses: []string{"http://localhost:9200"},
|
||||
APIKey: "somekey",
|
||||
},
|
||||
|
||||
@@ -27,13 +27,13 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
)
|
||||
|
||||
const SourceKind string = "firebird"
|
||||
const SourceType string = "firebird"
|
||||
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -55,8 +55,8 @@ type Config struct {
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -84,8 +84,8 @@ type Source struct {
|
||||
Db *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -144,7 +144,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
}
|
||||
|
||||
func initFirebirdConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*sql.DB, error) {
|
||||
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
_, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
// urlExample := "user:password@host:port/path/to/database.fdb"
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestParseFromYamlFirebird(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-fdb-instance": firebird.Config{
|
||||
Name: "my-fdb-instance",
|
||||
Kind: firebird.SourceKind,
|
||||
Type: firebird.SourceType,
|
||||
Host: "my-host",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
|
||||
@@ -16,7 +16,10 @@ package firestore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/firestore"
|
||||
"github.com/goccy/go-yaml"
|
||||
@@ -25,16 +28,17 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/api/firebaserules/v1"
|
||||
"google.golang.org/api/option"
|
||||
"google.golang.org/genproto/googleapis/type/latlng"
|
||||
)
|
||||
|
||||
const SourceKind string = "firestore"
|
||||
const SourceType string = "firestore"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,14 +53,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
// Firestore configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Database string `yaml:"database"` // Optional, defaults to "(default)"
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
func (r Config) SourceConfigType() string {
|
||||
// Returns Firestore source kind
|
||||
return SourceKind
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -88,9 +92,9 @@ type Source struct {
|
||||
RulesClient *firebaserules.Service
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
func (s *Source) SourceType() string {
|
||||
// Returns Firestore source kind
|
||||
return SourceKind
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -113,6 +117,476 @@ func (s *Source) GetDatabaseId() string {
|
||||
return s.Database
|
||||
}
|
||||
|
||||
// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
|
||||
// This removes type information and returns plain values
|
||||
func FirestoreValueToJSON(value any) any {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case time.Time:
|
||||
return v.Format(time.RFC3339Nano)
|
||||
case *latlng.LatLng:
|
||||
return map[string]any{
|
||||
"latitude": v.Latitude,
|
||||
"longitude": v.Longitude,
|
||||
}
|
||||
case []byte:
|
||||
return base64.StdEncoding.EncodeToString(v)
|
||||
case []any:
|
||||
result := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
result[i] = FirestoreValueToJSON(item)
|
||||
}
|
||||
return result
|
||||
case map[string]any:
|
||||
result := make(map[string]any)
|
||||
for k, val := range v {
|
||||
result[k] = FirestoreValueToJSON(val)
|
||||
}
|
||||
return result
|
||||
case *firestore.DocumentRef:
|
||||
return v.Path
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// BuildQuery constructs the Firestore query from parameters
|
||||
func (s *Source) BuildQuery(collectionPath string, filter firestore.EntityFilter, selectFields []string, field string, direction firestore.Direction, limit int, analyzeQuery bool) (*firestore.Query, error) {
|
||||
collection := s.FirestoreClient().Collection(collectionPath)
|
||||
query := collection.Query
|
||||
|
||||
// Process and apply filters if template is provided
|
||||
if filter != nil {
|
||||
query = query.WhereEntity(filter)
|
||||
}
|
||||
if len(selectFields) > 0 {
|
||||
query = query.Select(selectFields...)
|
||||
}
|
||||
if field != "" {
|
||||
query = query.OrderBy(field, direction)
|
||||
}
|
||||
query = query.Limit(limit)
|
||||
|
||||
// Apply analyze options if enabled
|
||||
if analyzeQuery {
|
||||
query = query.WithRunOptions(firestore.ExplainOptions{
|
||||
Analyze: true,
|
||||
})
|
||||
}
|
||||
|
||||
return &query, nil
|
||||
}
|
||||
|
||||
// QueryResult represents a document result from the query
|
||||
type QueryResult struct {
|
||||
ID string `json:"id"`
|
||||
Path string `json:"path"`
|
||||
Data map[string]any `json:"data"`
|
||||
CreateTime any `json:"createTime,omitempty"`
|
||||
UpdateTime any `json:"updateTime,omitempty"`
|
||||
ReadTime any `json:"readTime,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResponse represents the full response including optional metrics
|
||||
type QueryResponse struct {
|
||||
Documents []QueryResult `json:"documents"`
|
||||
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
|
||||
}
|
||||
|
||||
// ExecuteQuery runs the query and formats the results
|
||||
func (s *Source) ExecuteQuery(ctx context.Context, query *firestore.Query, analyzeQuery bool) (any, error) {
|
||||
docIterator := query.Documents(ctx)
|
||||
docs, err := docIterator.GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute query: %w", err)
|
||||
}
|
||||
// Convert results to structured format
|
||||
results := make([]QueryResult, len(docs))
|
||||
for i, doc := range docs {
|
||||
results[i] = QueryResult{
|
||||
ID: doc.Ref.ID,
|
||||
Path: doc.Ref.Path,
|
||||
Data: doc.Data(),
|
||||
CreateTime: doc.CreateTime,
|
||||
UpdateTime: doc.UpdateTime,
|
||||
ReadTime: doc.ReadTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Return with explain metrics if requested
|
||||
if analyzeQuery {
|
||||
explainMetrics, err := getExplainMetrics(docIterator)
|
||||
if err == nil && explainMetrics != nil {
|
||||
response := QueryResponse{
|
||||
Documents: results,
|
||||
ExplainMetrics: explainMetrics,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// getExplainMetrics extracts explain metrics from the query iterator
|
||||
func getExplainMetrics(docIterator *firestore.DocumentIterator) (map[string]any, error) {
|
||||
explainMetrics, err := docIterator.ExplainMetrics()
|
||||
if err != nil || explainMetrics == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metricsData := make(map[string]any)
|
||||
|
||||
// Add plan summary if available
|
||||
if explainMetrics.PlanSummary != nil {
|
||||
planSummary := make(map[string]any)
|
||||
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
|
||||
metricsData["planSummary"] = planSummary
|
||||
}
|
||||
|
||||
// Add execution stats if available
|
||||
if explainMetrics.ExecutionStats != nil {
|
||||
executionStats := make(map[string]any)
|
||||
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
|
||||
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
|
||||
|
||||
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
|
||||
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
|
||||
}
|
||||
|
||||
if explainMetrics.ExecutionStats.DebugStats != nil {
|
||||
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
|
||||
}
|
||||
|
||||
metricsData["executionStats"] = executionStats
|
||||
}
|
||||
|
||||
return metricsData, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
|
||||
// Create document references from paths
|
||||
docRefs := make([]*firestore.DocumentRef, len(documentPaths))
|
||||
for i, path := range documentPaths {
|
||||
docRefs[i] = s.FirestoreClient().Doc(path)
|
||||
}
|
||||
|
||||
// Get all documents
|
||||
snapshots, err := s.FirestoreClient().GetAll(ctx, docRefs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get documents: %w", err)
|
||||
}
|
||||
|
||||
// Convert snapshots to response data
|
||||
results := make([]any, len(snapshots))
|
||||
for i, snapshot := range snapshots {
|
||||
docData := make(map[string]any)
|
||||
docData["path"] = documentPaths[i]
|
||||
docData["exists"] = snapshot.Exists()
|
||||
|
||||
if snapshot.Exists() {
|
||||
docData["data"] = snapshot.Data()
|
||||
docData["createTime"] = snapshot.CreateTime
|
||||
docData["updateTime"] = snapshot.UpdateTime
|
||||
docData["readTime"] = snapshot.ReadTime
|
||||
}
|
||||
|
||||
results[i] = docData
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Source) AddDocuments(ctx context.Context, collectionPath string, documentData any, returnData bool) (map[string]any, error) {
|
||||
// Get the collection reference
|
||||
collection := s.FirestoreClient().Collection(collectionPath)
|
||||
|
||||
// Add the document to the collection
|
||||
docRef, writeResult, err := collection.Add(ctx, documentData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add document: %w", err)
|
||||
}
|
||||
// Build the response
|
||||
response := map[string]any{
|
||||
"documentPath": docRef.Path,
|
||||
"createTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
|
||||
}
|
||||
// Add document data if requested
|
||||
if returnData {
|
||||
// Fetch the updated document to return the current state
|
||||
snapshot, err := docRef.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
|
||||
}
|
||||
// Convert the document data back to simple JSON format
|
||||
simplifiedData := FirestoreValueToJSON(snapshot.Data())
|
||||
response["documentData"] = simplifiedData
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *Source) UpdateDocument(ctx context.Context, documentPath string, updates []firestore.Update, documentData any, returnData bool) (map[string]any, error) {
|
||||
// Get the document reference
|
||||
docRef := s.FirestoreClient().Doc(documentPath)
|
||||
|
||||
// Prepare update data
|
||||
var writeResult *firestore.WriteResult
|
||||
var writeErr error
|
||||
|
||||
if len(updates) > 0 {
|
||||
writeResult, writeErr = docRef.Update(ctx, updates)
|
||||
} else {
|
||||
writeResult, writeErr = docRef.Set(ctx, documentData, firestore.MergeAll)
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
return nil, fmt.Errorf("failed to update document: %w", writeErr)
|
||||
}
|
||||
|
||||
// Build the response
|
||||
response := map[string]any{
|
||||
"documentPath": docRef.Path,
|
||||
"updateTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
|
||||
}
|
||||
|
||||
// Add document data if requested
|
||||
if returnData {
|
||||
// Fetch the updated document to return the current state
|
||||
snapshot, err := docRef.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
|
||||
}
|
||||
// Convert the document data to simple JSON format
|
||||
simplifiedData := FirestoreValueToJSON(snapshot.Data())
|
||||
response["documentData"] = simplifiedData
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *Source) DeleteDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
|
||||
// Create a BulkWriter to handle multiple deletions efficiently
|
||||
bulkWriter := s.FirestoreClient().BulkWriter(ctx)
|
||||
|
||||
// Keep track of jobs for each document
|
||||
jobs := make([]*firestore.BulkWriterJob, len(documentPaths))
|
||||
|
||||
// Add all delete operations to the BulkWriter
|
||||
for i, path := range documentPaths {
|
||||
docRef := s.FirestoreClient().Doc(path)
|
||||
job, err := bulkWriter.Delete(docRef)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
|
||||
}
|
||||
jobs[i] = job
|
||||
}
|
||||
|
||||
// End the BulkWriter to execute all operations
|
||||
bulkWriter.End()
|
||||
|
||||
// Collect results
|
||||
results := make([]any, len(documentPaths))
|
||||
for i, job := range jobs {
|
||||
docData := make(map[string]any)
|
||||
docData["path"] = documentPaths[i]
|
||||
|
||||
// Wait for the job to complete and get the result
|
||||
_, err := job.Results()
|
||||
if err != nil {
|
||||
docData["success"] = false
|
||||
docData["error"] = err.Error()
|
||||
} else {
|
||||
docData["success"] = true
|
||||
}
|
||||
|
||||
results[i] = docData
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Source) ListCollections(ctx context.Context, parentPath string) ([]any, error) {
|
||||
var collectionRefs []*firestore.CollectionRef
|
||||
var err error
|
||||
if parentPath != "" {
|
||||
// List subcollections of the specified document
|
||||
docRef := s.FirestoreClient().Doc(parentPath)
|
||||
collectionRefs, err = docRef.Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
|
||||
}
|
||||
} else {
|
||||
// List root collections
|
||||
collectionRefs, err = s.FirestoreClient().Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list root collections: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert collection references to response data
|
||||
results := make([]any, len(collectionRefs))
|
||||
for i, collRef := range collectionRefs {
|
||||
collData := make(map[string]any)
|
||||
collData["id"] = collRef.ID
|
||||
collData["path"] = collRef.Path
|
||||
|
||||
// If this is a subcollection, include parent information
|
||||
if collRef.Parent != nil {
|
||||
collData["parent"] = collRef.Parent.Path
|
||||
}
|
||||
results[i] = collData
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetRules(ctx context.Context) (any, error) {
|
||||
// Get the latest release for Firestore
|
||||
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", s.GetProjectId(), s.GetDatabaseId())
|
||||
release, err := s.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
|
||||
}
|
||||
|
||||
if release.RulesetName == "" {
|
||||
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", s.GetProjectId(), s.GetDatabaseId())
|
||||
}
|
||||
|
||||
// Get the ruleset content
|
||||
ruleset, err := s.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
|
||||
}
|
||||
|
||||
if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
|
||||
return nil, fmt.Errorf("no rules files found in ruleset")
|
||||
}
|
||||
|
||||
return ruleset, nil
|
||||
}
|
||||
|
||||
// SourcePosition represents the location of an issue in the source
|
||||
type SourcePosition struct {
|
||||
FileName string `json:"fileName,omitempty"`
|
||||
Line int64 `json:"line"` // 1-based
|
||||
Column int64 `json:"column"` // 1-based
|
||||
CurrentOffset int64 `json:"currentOffset"` // 0-based, inclusive start
|
||||
EndOffset int64 `json:"endOffset"` // 0-based, exclusive end
|
||||
}
|
||||
|
||||
// Issue represents a validation issue in the rules
|
||||
type Issue struct {
|
||||
SourcePosition SourcePosition `json:"sourcePosition"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
}
|
||||
|
||||
// ValidationResult represents the result of rules validation
|
||||
type ValidationResult struct {
|
||||
Valid bool `json:"valid"`
|
||||
IssueCount int `json:"issueCount"`
|
||||
FormattedIssues string `json:"formattedIssues,omitempty"`
|
||||
RawIssues []Issue `json:"rawIssues,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Source) ValidateRules(ctx context.Context, sourceParam string) (any, error) {
|
||||
// Create test request
|
||||
testRequest := &firebaserules.TestRulesetRequest{
|
||||
Source: &firebaserules.Source{
|
||||
Files: []*firebaserules.File{
|
||||
{
|
||||
Name: "firestore.rules",
|
||||
Content: sourceParam,
|
||||
},
|
||||
},
|
||||
},
|
||||
// We don't need test cases for validation only
|
||||
TestSuite: &firebaserules.TestSuite{
|
||||
TestCases: []*firebaserules.TestCase{},
|
||||
},
|
||||
}
|
||||
// Call the test API
|
||||
projectName := fmt.Sprintf("projects/%s", s.GetProjectId())
|
||||
response, err := s.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate rules: %w", err)
|
||||
}
|
||||
|
||||
// Process the response
|
||||
if len(response.Issues) == 0 {
|
||||
return ValidationResult{
|
||||
Valid: true,
|
||||
IssueCount: 0,
|
||||
FormattedIssues: "✓ No errors detected. Rules are valid.",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Convert issues to our format
|
||||
issues := make([]Issue, len(response.Issues))
|
||||
for i, issue := range response.Issues {
|
||||
issues[i] = Issue{
|
||||
Description: issue.Description,
|
||||
Severity: issue.Severity,
|
||||
SourcePosition: SourcePosition{
|
||||
FileName: issue.SourcePosition.FileName,
|
||||
Line: issue.SourcePosition.Line,
|
||||
Column: issue.SourcePosition.Column,
|
||||
CurrentOffset: issue.SourcePosition.CurrentOffset,
|
||||
EndOffset: issue.SourcePosition.EndOffset,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Format issues
|
||||
sourceLines := strings.Split(sourceParam, "\n")
|
||||
var formattedOutput []string
|
||||
|
||||
formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))
|
||||
|
||||
for _, issue := range issues {
|
||||
issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
|
||||
issue.Severity,
|
||||
issue.Description,
|
||||
issue.SourcePosition.Line,
|
||||
issue.SourcePosition.Column)
|
||||
|
||||
if issue.SourcePosition.Line > 0 {
|
||||
lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
|
||||
if lineIndex >= 0 && lineIndex < len(sourceLines) {
|
||||
errorLine := sourceLines[lineIndex]
|
||||
issueString += fmt.Sprintf("\n```\n%s", errorLine)
|
||||
|
||||
// Add carets if we have column and offset information
|
||||
if issue.SourcePosition.Column > 0 &&
|
||||
issue.SourcePosition.CurrentOffset >= 0 &&
|
||||
issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {
|
||||
|
||||
startColumn := int(issue.SourcePosition.Column - 1) // 0-based
|
||||
errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)
|
||||
|
||||
if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
|
||||
padding := strings.Repeat(" ", startColumn)
|
||||
carets := strings.Repeat("^", errorTokenLength)
|
||||
issueString += fmt.Sprintf("\n%s%s", padding, carets)
|
||||
}
|
||||
}
|
||||
issueString += "\n```"
|
||||
}
|
||||
}
|
||||
|
||||
formattedOutput = append(formattedOutput, issueString)
|
||||
}
|
||||
|
||||
formattedIssues := strings.Join(formattedOutput, "\n\n")
|
||||
|
||||
return ValidationResult{
|
||||
Valid: false,
|
||||
IssueCount: len(issues),
|
||||
FormattedIssues: formattedIssues,
|
||||
RawIssues: issues,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func initFirestoreConnection(
|
||||
ctx context.Context,
|
||||
tracer trace.Tracer,
|
||||
@@ -120,7 +594,7 @@ func initFirestoreConnection(
|
||||
project string,
|
||||
database string,
|
||||
) (*firestore.Client, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
|
||||
@@ -16,6 +16,7 @@ package firestore_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -41,7 +42,7 @@ func TestParseFromYamlFirestore(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-firestore": firestore.Config{
|
||||
Name: "my-firestore",
|
||||
Kind: firestore.SourceKind,
|
||||
Type: firestore.SourceType,
|
||||
Project: "my-project",
|
||||
Database: "",
|
||||
},
|
||||
@@ -59,7 +60,7 @@ func TestParseFromYamlFirestore(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-firestore": firestore.Config{
|
||||
Name: "my-firestore",
|
||||
Kind: firestore.SourceKind,
|
||||
Type: firestore.SourceType,
|
||||
Project: "my-project",
|
||||
Database: "my-database",
|
||||
},
|
||||
@@ -128,3 +129,37 @@ func TestFailParseFromYamlFirestore(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
|
||||
// Test round-trip conversion
|
||||
original := map[string]any{
|
||||
"name": "Test",
|
||||
"count": int64(42),
|
||||
"price": 19.99,
|
||||
"active": true,
|
||||
"tags": []any{"tag1", "tag2"},
|
||||
"metadata": map[string]any{
|
||||
"created": time.Now(),
|
||||
},
|
||||
"nullField": nil,
|
||||
}
|
||||
|
||||
// Convert to JSON representation
|
||||
jsonRepresentation := firestore.FirestoreValueToJSON(original)
|
||||
|
||||
// Verify types are simplified
|
||||
jsonMap, ok := jsonRepresentation.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map, got %T", jsonRepresentation)
|
||||
}
|
||||
|
||||
// Time should be converted to string
|
||||
metadata, ok := jsonMap["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
|
||||
}
|
||||
_, ok = metadata["created"].(string)
|
||||
if !ok {
|
||||
t.Errorf("created should be a string, got %T", metadata["created"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,9 @@ package http
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
@@ -27,14 +29,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "http"
|
||||
const SourceType string = "http"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
BaseURL string `yaml:"baseUrl"`
|
||||
Timeout string `yaml:"timeout"`
|
||||
DefaultHeaders map[string]string `yaml:"headers"`
|
||||
@@ -56,8 +58,8 @@ type Config struct {
|
||||
DisableSslVerification bool `yaml:"disableSslVerification"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
// Initialize initializes an HTTP Source instance.
|
||||
@@ -120,8 +122,8 @@ type Source struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -143,3 +145,28 @@ func (s *Source) HttpQueryParams() map[string]string {
|
||||
func (s *Source) Client() *http.Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
func (s *Source) RunRequest(req *http.Request) (any, error) {
|
||||
// Make request and fetch response
|
||||
resp, err := s.Client().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making HTTP request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var body []byte
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var data any
|
||||
if err = json.Unmarshal(body, &data); err != nil {
|
||||
// if unable to unmarshal data, return result as string.
|
||||
return string(body), nil
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func TestParseFromYamlHttp(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-http-instance": http.Config{
|
||||
Name: "my-http-instance",
|
||||
Kind: http.SourceKind,
|
||||
Type: http.SourceType,
|
||||
BaseURL: "http://test_server/",
|
||||
Timeout: "30s",
|
||||
DisableSslVerification: false,
|
||||
@@ -68,7 +68,7 @@ func TestParseFromYamlHttp(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-http-instance": http.Config{
|
||||
Name: "my-http-instance",
|
||||
Kind: http.SourceKind,
|
||||
Type: http.SourceType,
|
||||
BaseURL: "http://test_server/",
|
||||
Timeout: "10s",
|
||||
DefaultHeaders: map[string]string{"Authorization": "test_header", "Custom-Header": "custom"},
|
||||
|
||||
@@ -15,7 +15,9 @@ package looker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -31,14 +33,14 @@ import (
|
||||
v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4"
|
||||
)
|
||||
|
||||
const SourceKind string = "looker"
|
||||
const SourceType string = "looker"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,7 +64,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
BaseURL string `yaml:"base_url" validate:"required"`
|
||||
ClientId string `yaml:"client_id"`
|
||||
ClientSecret string `yaml:"client_secret"`
|
||||
@@ -77,8 +79,8 @@ type Config struct {
|
||||
SessionLength int64 `yaml:"sessionLength"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
// Initialize initializes a Looker Source instance.
|
||||
@@ -152,8 +154,8 @@ type Source struct {
|
||||
AuthTokenHeaderName string
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -208,6 +210,49 @@ func (s *Source) LookerSessionLength() int64 {
|
||||
return s.SessionLength
|
||||
}
|
||||
|
||||
// Make types for RoundTripper
|
||||
type transportWithAuthHeader struct {
|
||||
Base http.RoundTripper
|
||||
AuthToken string
|
||||
}
|
||||
|
||||
func (t *transportWithAuthHeader) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("x-looker-appid", "go-sdk")
|
||||
req.Header.Set("Authorization", t.AuthToken)
|
||||
return t.Base.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (s *Source) GetLookerSDK(accessToken string) (*v4.LookerSDK, error) {
|
||||
if s.UseClientAuthorization() {
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token supplied with request")
|
||||
}
|
||||
|
||||
session := rtl.NewAuthSession(*s.LookerApiSettings())
|
||||
// Configure base transport with TLS
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: !s.LookerApiSettings().VerifySsl,
|
||||
},
|
||||
}
|
||||
|
||||
// Build transport for end user token
|
||||
session.Client = http.Client{
|
||||
Transport: &transportWithAuthHeader{
|
||||
Base: transport,
|
||||
AuthToken: accessToken,
|
||||
},
|
||||
}
|
||||
// return SDK with new Transport
|
||||
return v4.NewLookerSDK(session), nil
|
||||
}
|
||||
|
||||
if s.LookerClient() == nil {
|
||||
return nil, fmt.Errorf("client id or client secret not valid")
|
||||
}
|
||||
return s.LookerClient(), nil
|
||||
}
|
||||
|
||||
func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) {
|
||||
cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...)
|
||||
if err != nil {
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestParseFromYamlLooker(t *testing.T) {
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-looker-instance": looker.Config{
|
||||
Name: "my-looker-instance",
|
||||
Kind: looker.SourceKind,
|
||||
Type: looker.SourceType,
|
||||
BaseURL: "http://example.looker.com/",
|
||||
ClientId: "jasdl;k;tjl",
|
||||
ClientSecret: "sdakl;jgflkasdfkfg",
|
||||
|
||||
@@ -27,14 +27,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "mindsdb"
|
||||
const SourceType string = "mindsdb"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -57,8 +57,8 @@ type Config struct {
|
||||
QueryTimeout string `yaml:"queryTimeout"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -86,8 +86,8 @@ type Source struct {
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -159,7 +159,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initMindsDBConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestParseFromYamlMindsDB(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-mindsdb-instance": mindsdb.Config{
|
||||
Name: "my-mindsdb-instance",
|
||||
Kind: mindsdb.SourceKind,
|
||||
Type: mindsdb.SourceType,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -70,7 +70,7 @@ func TestParseFromYamlMindsDB(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-mindsdb-instance": mindsdb.Config{
|
||||
Name: "my-mindsdb-instance",
|
||||
Kind: mindsdb.SourceKind,
|
||||
Type: mindsdb.SourceType,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
|
||||
@@ -16,24 +16,27 @@ package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "mongodb"
|
||||
const SourceType string = "mongodb"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,12 +50,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Uri string `yaml:"uri" validate:"required"` // MongoDB Atlas connection URI
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -81,8 +84,8 @@ type Source struct {
|
||||
Client *mongo.Client
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -93,9 +96,204 @@ func (s *Source) MongoClient() *mongo.Client {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func parseData(ctx context.Context, cur *mongo.Cursor) ([]any, error) {
|
||||
var data = []any{}
|
||||
err := cur.All(ctx, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var final []any
|
||||
for _, item := range data {
|
||||
tmp, _ := bson.MarshalExtJSON(item, false, false)
|
||||
var tmp2 any
|
||||
err = json.Unmarshal(tmp, &tmp2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
final = append(final, tmp2)
|
||||
}
|
||||
return final, err
|
||||
}
|
||||
|
||||
func (s *Source) Aggregate(ctx context.Context, pipelineString string, canonical, readOnly bool, database, collection string) ([]any, error) {
|
||||
var pipeline = []bson.M{}
|
||||
err := bson.UnmarshalExtJSON([]byte(pipelineString), canonical, &pipeline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if readOnly {
|
||||
//fail if we do a merge or an out
|
||||
for _, stage := range pipeline {
|
||||
for key := range stage {
|
||||
if key == "$merge" || key == "$out" {
|
||||
return nil, fmt.Errorf("this is not a read-only pipeline: %+v", stage)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cur, err := s.MongoClient().Database(database).Collection(collection).Aggregate(ctx, pipeline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cur.Close(ctx)
|
||||
res, err := parseData(ctx, cur)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res == nil {
|
||||
return []any{}, nil
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (s *Source) Find(ctx context.Context, filterString, database, collection string, opts *options.FindOptions) ([]any, error) {
|
||||
var filter = bson.D{}
|
||||
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cur, err := s.MongoClient().Database(database).Collection(collection).Find(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cur.Close(ctx)
|
||||
return parseData(ctx, cur)
|
||||
}
|
||||
|
||||
func (s *Source) FindOne(ctx context.Context, filterString, database, collection string, opts *options.FindOneOptions) ([]any, error) {
|
||||
var filter = bson.D{}
|
||||
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := s.MongoClient().Database(database).Collection(collection).FindOne(ctx, filter, opts)
|
||||
if res.Err() != nil {
|
||||
return nil, res.Err()
|
||||
}
|
||||
|
||||
var data any
|
||||
err = res.Decode(&data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var final []any
|
||||
tmp, _ := bson.MarshalExtJSON(data, false, false)
|
||||
var tmp2 any
|
||||
err = json.Unmarshal(tmp, &tmp2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
final = append(final, tmp2)
|
||||
|
||||
return final, err
|
||||
}
|
||||
|
||||
func (s *Source) InsertMany(ctx context.Context, jsonData string, canonical bool, database, collection string) ([]any, error) {
|
||||
var data = []any{}
|
||||
err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := s.MongoClient().Database(database).Collection(collection).InsertMany(ctx, data, options.InsertMany())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.InsertedIDs, nil
|
||||
}
|
||||
|
||||
func (s *Source) InsertOne(ctx context.Context, jsonData string, canonical bool, database, collection string) (any, error) {
|
||||
var data any
|
||||
err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := s.MongoClient().Database(database).Collection(collection).InsertOne(ctx, data, options.InsertOne())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.InsertedID, nil
|
||||
}
|
||||
|
||||
func (s *Source) UpdateMany(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) ([]any, error) {
|
||||
var filter = bson.D{}
|
||||
err := bson.UnmarshalExtJSON([]byte(filterString), canonical, &filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
|
||||
}
|
||||
var update = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(updateString), false, &update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
|
||||
}
|
||||
|
||||
res, err := s.MongoClient().Database(database).Collection(collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(upsert))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error updating collection: %w", err)
|
||||
}
|
||||
return []any{res.ModifiedCount, res.UpsertedCount, res.MatchedCount}, nil
|
||||
}
|
||||
|
||||
func (s *Source) UpdateOne(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) (any, error) {
|
||||
var filter = bson.D{}
|
||||
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
|
||||
}
|
||||
var update = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(updateString), canonical, &update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
|
||||
}
|
||||
|
||||
res, err := s.MongoClient().Database(database).Collection(collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(upsert))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error updating collection: %w", err)
|
||||
}
|
||||
return res.ModifiedCount, nil
|
||||
}
|
||||
|
||||
func (s *Source) DeleteMany(ctx context.Context, filterString, database, collection string) (any, error) {
|
||||
var filter = bson.D{}
|
||||
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := s.MongoClient().Database(database).Collection(collection).DeleteMany(ctx, filter, options.Delete())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.DeletedCount == 0 {
|
||||
return nil, errors.New("no document found")
|
||||
}
|
||||
return res.DeletedCount, nil
|
||||
}
|
||||
|
||||
func (s *Source) DeleteOne(ctx context.Context, filterString, database, collection string) (any, error) {
|
||||
var filter = bson.D{}
|
||||
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := s.MongoClient().Database(database).Collection(collection).DeleteOne(ctx, filter, options.Delete())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.DeletedCount, nil
|
||||
}
|
||||
|
||||
func initMongoDBClient(ctx context.Context, tracer trace.Tracer, name, uri string) (*mongo.Client, error) {
|
||||
// Start a tracing span
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestParseFromYamlMongoDB(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"mongo-db": mongodb.Config{
|
||||
Name: "mongo-db",
|
||||
Kind: mongodb.SourceKind,
|
||||
Type: mongodb.SourceType,
|
||||
Uri: "mongodb+srv://username:password@host/dbname",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -28,14 +28,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "mssql"
|
||||
const SourceType string = "mssql"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
// Cloud SQL MSSQL configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -59,9 +59,9 @@ type Config struct {
|
||||
Encrypt string `yaml:"encrypt"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
func (r Config) SourceConfigType() string {
|
||||
// Returns Cloud SQL MSSQL source kind
|
||||
return SourceKind
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -91,9 +91,9 @@ type Source struct {
|
||||
Db *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
func (s *Source) SourceType() string {
|
||||
// Returns Cloud SQL MSSQL source kind
|
||||
return SourceKind
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -156,7 +156,7 @@ func initMssqlConnection(
|
||||
error,
|
||||
) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestParseFromYamlMssql(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-mssql-instance": mssql.Config{
|
||||
Name: "my-mssql-instance",
|
||||
Kind: mssql.SourceKind,
|
||||
Type: mssql.SourceType,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -70,7 +70,7 @@ func TestParseFromYamlMssql(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-mssql-instance": mssql.Config{
|
||||
Name: "my-mssql-instance",
|
||||
Kind: mssql.SourceKind,
|
||||
Type: mssql.SourceType,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
|
||||
@@ -30,14 +30,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "mysql"
|
||||
const SourceType string = "mysql"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -61,8 +61,8 @@ type Config struct {
|
||||
QueryParams map[string]string `yaml:"queryParams"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -90,8 +90,8 @@ type Source struct {
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -158,7 +158,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string, queryParams map[string]string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
// Build query parameters via url.Values for deterministic order and proper escaping.
|
||||
|
||||
@@ -50,7 +50,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": mysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Kind: mysql.SourceKind,
|
||||
Type: mysql.SourceType,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -75,7 +75,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": mysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Kind: mysql.SourceKind,
|
||||
Type: mysql.SourceType,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -103,7 +103,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": mysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Kind: mysql.SourceKind,
|
||||
Type: mysql.SourceType,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -211,7 +211,7 @@ func TestFailInitialization(t *testing.T) {
|
||||
|
||||
cfg := mysql.Config{
|
||||
Name: "instance",
|
||||
Kind: "mysql",
|
||||
Type: "mysql",
|
||||
Host: "localhost",
|
||||
Port: "3306",
|
||||
Database: "db",
|
||||
|
||||
@@ -29,7 +29,7 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "neo4j"
|
||||
const SourceType string = "neo4j"
|
||||
|
||||
var sourceClassifier *classifier.QueryClassifier = classifier.NewQueryClassifier()
|
||||
|
||||
@@ -37,8 +37,8 @@ var sourceClassifier *classifier.QueryClassifier = classifier.NewQueryClassifier
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,15 +52,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Uri string `yaml:"uri" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -91,8 +91,8 @@ type Source struct {
|
||||
Driver neo4j.DriverWithContext
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -182,7 +182,7 @@ func addPlanChildren(p neo4j.Plan) []map[string]any {
|
||||
|
||||
func initNeo4jDriver(ctx context.Context, tracer trace.Tracer, uri, user, password, name string) (neo4j.DriverWithContext, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
auth := neo4j.BasicAuth(user, password, "")
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestParseFromYamlNeo4j(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-neo4j-instance": neo4j.Config{
|
||||
Name: "my-neo4j-instance",
|
||||
Kind: neo4j.SourceKind,
|
||||
Type: neo4j.SourceType,
|
||||
Uri: "neo4j+s://my-host:7687",
|
||||
Database: "my_db",
|
||||
User: "my_user",
|
||||
|
||||
@@ -27,14 +27,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "oceanbase"
|
||||
const SourceType string = "oceanbase"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -57,8 +57,8 @@ type Config struct {
|
||||
QueryTimeout string `yaml:"queryTimeout"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -86,8 +86,8 @@ type Source struct {
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -153,7 +153,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
}
|
||||
|
||||
func initOceanBaseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
||||
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
_, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, pass, host, port, dbname)
|
||||
|
||||
@@ -46,7 +46,7 @@ func TestParseFromYamlOceanBase(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-oceanbase-instance": oceanbase.Config{
|
||||
Name: "my-oceanbase-instance",
|
||||
Kind: oceanbase.SourceKind,
|
||||
Type: oceanbase.SourceType,
|
||||
Host: "0.0.0.0",
|
||||
Port: "2881",
|
||||
Database: "ob_db",
|
||||
@@ -71,7 +71,7 @@ func TestParseFromYamlOceanBase(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-oceanbase-instance": oceanbase.Config{
|
||||
Name: "my-oceanbase-instance",
|
||||
Kind: oceanbase.SourceKind,
|
||||
Type: oceanbase.SourceType,
|
||||
Host: "0.0.0.0",
|
||||
Port: "2881",
|
||||
Database: "ob_db",
|
||||
|
||||
@@ -18,14 +18,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "oracle"
|
||||
const SourceType string = "oracle"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
ConnectionString string `yaml:"connectionString,omitempty"`
|
||||
TnsAlias string `yaml:"tnsAlias,omitempty"`
|
||||
TnsAdmin string `yaml:"tnsAdmin,omitempty"`
|
||||
@@ -95,8 +95,8 @@ func (c Config) validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -124,8 +124,8 @@ type Source struct {
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -239,7 +239,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Config) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, config.Name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, config.Name)
|
||||
defer span.End()
|
||||
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
|
||||
@@ -33,7 +33,7 @@ func TestParseFromYamlOracle(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-oracle-cs": oracle.Config{
|
||||
Name: "my-oracle-cs",
|
||||
Kind: oracle.SourceKind,
|
||||
Type: oracle.SourceType,
|
||||
ConnectionString: "my-host:1521/XEPDB1",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
@@ -56,7 +56,7 @@ func TestParseFromYamlOracle(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-oracle-host": oracle.Config{
|
||||
Name: "my-oracle-host",
|
||||
Kind: oracle.SourceKind,
|
||||
Type: oracle.SourceType,
|
||||
Host: "my-host",
|
||||
Port: 1521,
|
||||
ServiceName: "ORCLPDB",
|
||||
@@ -81,7 +81,7 @@ func TestParseFromYamlOracle(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-oracle-tns-oci": oracle.Config{
|
||||
Name: "my-oracle-tns-oci",
|
||||
Kind: oracle.SourceKind,
|
||||
Type: oracle.SourceType,
|
||||
TnsAlias: "FINANCE_DB",
|
||||
TnsAdmin: "/opt/oracle/network/admin",
|
||||
User: "my_user",
|
||||
|
||||
@@ -28,14 +28,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "postgres"
|
||||
const SourceType string = "postgres"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -58,8 +58,8 @@ type Config struct {
|
||||
QueryParams map[string]string `yaml:"queryParams"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -87,8 +87,8 @@ type Source struct {
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -128,7 +128,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -47,7 +47,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": postgres.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: postgres.SourceKind,
|
||||
Type: postgres.SourceType,
|
||||
Host: "my-host",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -74,7 +74,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": postgres.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: postgres.SourceKind,
|
||||
Type: postgres.SourceType,
|
||||
Host: "my-host",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
|
||||
@@ -24,14 +24,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "redis"
|
||||
const SourceType string = "redis"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Address []string `yaml:"address" validate:"required"`
|
||||
Username string `yaml:"username"`
|
||||
Password string `yaml:"password"`
|
||||
@@ -54,8 +54,8 @@ type Config struct {
|
||||
ClusterEnabled bool `yaml:"clusterEnabled"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
// RedisClient is an interface for `redis.Client` and `redis.ClusterClient
|
||||
@@ -141,8 +141,8 @@ type Source struct {
|
||||
Client RedisClient
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestParseFromYamlRedis(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-redis-instance": redis.Config{
|
||||
Name: "my-redis-instance",
|
||||
Kind: redis.SourceKind,
|
||||
Type: redis.SourceType,
|
||||
Address: []string{"127.0.0.1"},
|
||||
ClusterEnabled: false,
|
||||
UseGCPIAM: false,
|
||||
@@ -66,7 +66,7 @@ func TestParseFromYamlRedis(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-redis-instance": redis.Config{
|
||||
Name: "my-redis-instance",
|
||||
Kind: redis.SourceKind,
|
||||
Type: redis.SourceType,
|
||||
Address: []string{"127.0.0.1"},
|
||||
Password: "my-pass",
|
||||
Database: 1,
|
||||
|
||||
@@ -16,25 +16,31 @@ package serverlessspark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
|
||||
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
longrunning "cloud.google.com/go/longrunning/autogen"
|
||||
"cloud.google.com/go/longrunning/autogen/longrunningpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/api/iterator"
|
||||
"google.golang.org/api/option"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const SourceKind string = "serverless-spark"
|
||||
const SourceType string = "serverless-spark"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,13 +54,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Location string `yaml:"location" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -88,8 +94,8 @@ type Source struct {
|
||||
OpsClient *longrunning.OperationsClient
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -121,3 +127,168 @@ func (s *Source) Close() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) CancelOperation(ctx context.Context, operation string) (any, error) {
|
||||
req := &longrunningpb.CancelOperationRequest{
|
||||
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", s.GetProject(), s.GetLocation(), operation),
|
||||
}
|
||||
client, err := s.GetOperationsClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get operations client: %w", err)
|
||||
}
|
||||
err = client.CancelOperation(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to cancel operation: %w", err)
|
||||
}
|
||||
return fmt.Sprintf("Cancelled [%s].", operation), nil
|
||||
}
|
||||
|
||||
func (s *Source) CreateBatch(ctx context.Context, batch *dataprocpb.Batch) (map[string]any, error) {
|
||||
req := &dataprocpb.CreateBatchRequest{
|
||||
Parent: fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation()),
|
||||
Batch: batch,
|
||||
}
|
||||
|
||||
client := s.GetBatchControllerClient()
|
||||
op, err := client.CreateBatch(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create batch: %w", err)
|
||||
}
|
||||
meta, err := op.Metadata()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
|
||||
}
|
||||
|
||||
projectID, location, batchID, err := ExtractBatchDetails(meta.GetBatch())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
|
||||
}
|
||||
consoleUrl := BatchConsoleURL(projectID, location, batchID)
|
||||
logsUrl := BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})
|
||||
|
||||
wrappedResult := map[string]any{
|
||||
"opMetadata": meta,
|
||||
"consoleUrl": consoleUrl,
|
||||
"logsUrl": logsUrl,
|
||||
}
|
||||
return wrappedResult, nil
|
||||
}
|
||||
|
||||
// ListBatchesResponse is the response from the list batches API.
|
||||
type ListBatchesResponse struct {
|
||||
Batches []Batch `json:"batches"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
}
|
||||
|
||||
// Batch represents a single batch job.
|
||||
type Batch struct {
|
||||
Name string `json:"name"`
|
||||
UUID string `json:"uuid"`
|
||||
State string `json:"state"`
|
||||
Creator string `json:"creator"`
|
||||
CreateTime string `json:"createTime"`
|
||||
Operation string `json:"operation"`
|
||||
ConsoleURL string `json:"consoleUrl"`
|
||||
LogsURL string `json:"logsUrl"`
|
||||
}
|
||||
|
||||
func (s *Source) ListBatches(ctx context.Context, ps *int, pt, filter string) (any, error) {
|
||||
client := s.GetBatchControllerClient()
|
||||
parent := fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation())
|
||||
req := &dataprocpb.ListBatchesRequest{
|
||||
Parent: parent,
|
||||
OrderBy: "create_time desc",
|
||||
}
|
||||
|
||||
if ps != nil {
|
||||
req.PageSize = int32(*ps)
|
||||
}
|
||||
if pt != "" {
|
||||
req.PageToken = pt
|
||||
}
|
||||
if filter != "" {
|
||||
req.Filter = filter
|
||||
}
|
||||
|
||||
it := client.ListBatches(ctx, req)
|
||||
pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)
|
||||
|
||||
var batchPbs []*dataprocpb.Batch
|
||||
nextPageToken, err := pager.NextPage(&batchPbs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list batches: %w", err)
|
||||
}
|
||||
|
||||
batches, err := ToBatches(batchPbs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
|
||||
}
|
||||
|
||||
// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
|
||||
func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
|
||||
batches := make([]Batch, 0, len(batchPbs))
|
||||
for _, batchPb := range batchPbs {
|
||||
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating console url: %v", err)
|
||||
}
|
||||
logsUrl, err := BatchLogsURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating logs url: %v", err)
|
||||
}
|
||||
batch := Batch{
|
||||
Name: batchPb.Name,
|
||||
UUID: batchPb.Uuid,
|
||||
State: batchPb.State.Enum().String(),
|
||||
Creator: batchPb.Creator,
|
||||
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
|
||||
Operation: batchPb.Operation,
|
||||
ConsoleURL: consoleUrl,
|
||||
LogsURL: logsUrl,
|
||||
}
|
||||
batches = append(batches, batch)
|
||||
}
|
||||
return batches, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetBatch(ctx context.Context, name string) (map[string]any, error) {
|
||||
client := s.GetBatchControllerClient()
|
||||
req := &dataprocpb.GetBatchRequest{
|
||||
Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", s.GetProject(), s.GetLocation(), name),
|
||||
}
|
||||
|
||||
batchPb, err := client.GetBatch(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get batch: %w", err)
|
||||
}
|
||||
|
||||
jsonBytes, err := protojson.Marshal(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
|
||||
}
|
||||
|
||||
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating console url: %v", err)
|
||||
}
|
||||
logsUrl, err := BatchLogsURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating logs url: %v", err)
|
||||
}
|
||||
|
||||
wrappedResult := map[string]any{
|
||||
"consoleUrl": consoleUrl,
|
||||
"logsUrl": logsUrl,
|
||||
"batch": result,
|
||||
}
|
||||
|
||||
return wrappedResult, nil
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func TestParseFromYamlServerlessSpark(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": serverlessspark.Config{
|
||||
Name: "my-instance",
|
||||
Kind: serverlessspark.SourceKind,
|
||||
Type: serverlessspark.SourceType,
|
||||
Project: "my-project",
|
||||
Location: "my-location",
|
||||
},
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// Copyright 2025 Google LLC
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package common
|
||||
package serverlessspark
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -23,13 +23,13 @@ import (
|
||||
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
)
|
||||
|
||||
var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)
|
||||
|
||||
const (
|
||||
logTimeBufferBefore = 1 * time.Minute
|
||||
logTimeBufferAfter = 10 * time.Minute
|
||||
)
|
||||
|
||||
var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)
|
||||
|
||||
// Extract BatchDetails extracts the project ID, location, and batch ID from a fully qualified batch name.
|
||||
func ExtractBatchDetails(batchName string) (projectID, location, batchID string, err error) {
|
||||
matches := batchFullNameRegex.FindStringSubmatch(batchName)
|
||||
@@ -39,26 +39,6 @@ func ExtractBatchDetails(batchName string) (projectID, location, batchID string,
|
||||
return matches[1], matches[2], matches[3], nil
|
||||
}
|
||||
|
||||
// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
|
||||
func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return BatchConsoleURL(projectID, location, batchID), nil
|
||||
}
|
||||
|
||||
// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
|
||||
func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
createTime := batchPb.GetCreateTime().AsTime()
|
||||
stateTime := batchPb.GetStateTime().AsTime()
|
||||
return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
|
||||
}
|
||||
|
||||
// BatchConsoleURL builds a URL to the Google Cloud Console linking to the batch summary page.
|
||||
func BatchConsoleURL(projectID, location, batchID string) string {
|
||||
return fmt.Sprintf("https://console.cloud.google.com/dataproc/batches/%s/%s/summary?project=%s", location, batchID, projectID)
|
||||
@@ -89,3 +69,23 @@ resource.labels.batch_id="%s"`
|
||||
|
||||
return "https://console.cloud.google.com/logs/viewer?" + v.Encode()
|
||||
}
|
||||
|
||||
// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
|
||||
func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return BatchConsoleURL(projectID, location, batchID), nil
|
||||
}
|
||||
|
||||
// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
|
||||
func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
createTime := batchPb.GetCreateTime().AsTime()
|
||||
stateTime := batchPb.GetStateTime().AsTime()
|
||||
return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
// Copyright 2025 Google LLC
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -12,19 +12,20 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package common
|
||||
package serverlessspark_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
func TestExtractBatchDetails_Success(t *testing.T) {
|
||||
batchName := "projects/my-project/locations/us-central1/batches/my-batch"
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchName)
|
||||
projectID, location, batchID, err := serverlessspark.ExtractBatchDetails(batchName)
|
||||
if err != nil {
|
||||
t.Errorf("ExtractBatchDetails() error = %v, want no error", err)
|
||||
return
|
||||
@@ -45,7 +46,7 @@ func TestExtractBatchDetails_Success(t *testing.T) {
|
||||
|
||||
func TestExtractBatchDetails_Failure(t *testing.T) {
|
||||
batchName := "invalid-name"
|
||||
_, _, _, err := ExtractBatchDetails(batchName)
|
||||
_, _, _, err := serverlessspark.ExtractBatchDetails(batchName)
|
||||
wantErr := "failed to parse batch name: invalid-name"
|
||||
if err == nil || err.Error() != wantErr {
|
||||
t.Errorf("ExtractBatchDetails() error = %v, want %v", err, wantErr)
|
||||
@@ -53,7 +54,7 @@ func TestExtractBatchDetails_Failure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBatchConsoleURL(t *testing.T) {
|
||||
got := BatchConsoleURL("my-project", "us-central1", "my-batch")
|
||||
got := serverlessspark.BatchConsoleURL("my-project", "us-central1", "my-batch")
|
||||
want := "https://console.cloud.google.com/dataproc/batches/us-central1/my-batch/summary?project=my-project"
|
||||
if got != want {
|
||||
t.Errorf("BatchConsoleURL() = %v, want %v", got, want)
|
||||
@@ -63,7 +64,7 @@ func TestBatchConsoleURL(t *testing.T) {
|
||||
func TestBatchLogsURL(t *testing.T) {
|
||||
startTime := time.Date(2025, 10, 1, 5, 0, 0, 0, time.UTC)
|
||||
endTime := time.Date(2025, 10, 1, 6, 0, 0, 0, time.UTC)
|
||||
got := BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
|
||||
got := serverlessspark.BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
|
||||
want := "https://console.cloud.google.com/logs/viewer?advancedFilter=" +
|
||||
"resource.type%3D%22cloud_dataproc_batch%22" +
|
||||
"%0Aresource.labels.project_id%3D%22my-project%22" +
|
||||
@@ -82,7 +83,7 @@ func TestBatchConsoleURLFromProto(t *testing.T) {
|
||||
batchPb := &dataprocpb.Batch{
|
||||
Name: "projects/my-project/locations/us-central1/batches/my-batch",
|
||||
}
|
||||
got, err := BatchConsoleURLFromProto(batchPb)
|
||||
got, err := serverlessspark.BatchConsoleURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
t.Fatalf("BatchConsoleURLFromProto() error = %v", err)
|
||||
}
|
||||
@@ -100,7 +101,7 @@ func TestBatchLogsURLFromProto(t *testing.T) {
|
||||
CreateTime: timestamppb.New(createTime),
|
||||
StateTime: timestamppb.New(stateTime),
|
||||
}
|
||||
got, err := BatchLogsURLFromProto(batchPb)
|
||||
got, err := serverlessspark.BatchLogsURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
t.Fatalf("BatchLogsURLFromProto() error = %v", err)
|
||||
}
|
||||
@@ -29,15 +29,15 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// SourceKind for SingleStore source
|
||||
const SourceKind string = "singlestore"
|
||||
// SourceType for SingleStore source
|
||||
const SourceType string = "singlestore"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
// Config holds the configuration parameters for connecting to a SingleStore database.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -61,9 +61,9 @@ type Config struct {
|
||||
QueryTimeout string `yaml:"queryTimeout"`
|
||||
}
|
||||
|
||||
// SourceConfigKind returns the kind of the source configuration.
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
// SourceConfigType returns the kind of the source configuration.
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
// Initialize sets up the SingleStore connection pool and returns a Source.
|
||||
@@ -93,9 +93,9 @@ type Source struct {
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
// SourceKind returns the kind of the source configuration.
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
// SourceType returns the kind of the source configuration.
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -162,7 +162,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initSingleStoreConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestParseFromYaml(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-s2-instance": singlestore.Config{
|
||||
Name: "my-s2-instance",
|
||||
Kind: singlestore.SourceKind,
|
||||
Type: singlestore.SourceType,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -70,7 +70,7 @@ func TestParseFromYaml(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-s2-instance": singlestore.Config{
|
||||
Name: "my-s2-instance",
|
||||
Kind: singlestore.SourceKind,
|
||||
Type: singlestore.SourceType,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
|
||||
@@ -25,14 +25,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "snowflake"
|
||||
const SourceType string = "snowflake"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Type string `yaml:"kind" validate:"required"`
|
||||
Account string `yaml:"account" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
@@ -56,8 +56,8 @@ type Config struct {
|
||||
Role string `yaml:"role"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -85,8 +85,8 @@ type Source struct {
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -137,7 +137,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initSnowflakeConnection(ctx context.Context, tracer trace.Tracer, name, account, user, password, database, schema, warehouse, role string) (*sqlx.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
// Set defaults for optional parameters
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestParseFromYamlSnowflake(t *testing.T) {
|
||||
want: server.SourceConfigs{
|
||||
"my-snowflake-instance": snowflake.Config{
|
||||
Name: "my-snowflake-instance",
|
||||
Kind: snowflake.SourceKind,
|
||||
Type: snowflake.SourceType,
|
||||
Account: "my-account",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user