Compare commits

..

34 Commits

Author SHA1 Message Date
Anubhav Dhawan
57a2ade4c5 feat(npm): Bump server package versions to 0.25.0 and generate new platform-specific build artifacts. (#2284)
Releases the NPM package versions with latest toolbox release
(`v.0.25.0`).
2026-01-09 15:34:33 +05:30
Twisha Bansal
bcb3566ba3 update versions'
update versions
2026-01-09 14:32:43 +05:30
Twisha Bansal
47bbef10c5 update version 2026-01-09 14:32:43 +05:30
Twisha Bansal
44e782371e script for server release 2026-01-09 14:32:43 +05:30
Twisha Bansal
705be4018d update dep versions 2026-01-09 14:32:43 +05:30
Twisha Bansal
c968ed8925 rename file 2026-01-09 14:32:43 +05:30
Twisha Bansal
6f5911288d release dep package script 2026-01-09 14:32:43 +05:30
Twisha Bansal
cf42389d39 add script to update server packages 2026-01-09 14:32:43 +05:30
Twisha Bansal
49cb83ca5d update versions 2026-01-09 14:32:43 +05:30
Twisha Bansal
961c4726c7 bump version 2026-01-09 14:32:43 +05:30
Twisha Bansal
6e961ad48b update supported platforms 2026-01-09 14:32:43 +05:30
Twisha Bansal
d5a9031a65 bump version 2026-01-09 14:32:43 +05:30
Twisha Bansal
669b9b3b5d update linux package versions 2026-01-09 14:32:43 +05:30
Twisha Bansal
7bf04ddbb1 fix server package 2026-01-09 14:32:43 +05:30
Twisha Bansal
2705cf642c update server versions 2026-01-09 14:32:43 +05:30
Twisha Bansal
fb8f757eaf upate version 2026-01-09 14:32:43 +05:30
Twisha Bansal
6731e25cff remove apple quarantine 2026-01-09 14:32:43 +05:30
Twisha Bansal
0aece20ada Update downloadBinary.js 2026-01-09 14:32:43 +05:30
Twisha Bansal
401d0307d5 Update downloadBinary.js 2026-01-09 14:32:43 +05:30
Twisha Bansal
7479f5b871 Update downloadBinary.js 2026-01-09 14:32:43 +05:30
Twisha Bansal
a61a1d78a3 Update downloadBinary.js 2026-01-09 14:32:43 +05:30
Twisha Bansal
2fbf51968d Update testing.md 2026-01-09 14:32:43 +05:30
Twisha Bansal
54b56ee7cc fix 2026-01-09 14:32:43 +05:30
Twisha Bansal
39615a6011 fix path 2026-01-09 14:32:43 +05:30
Twisha Bansal
83562eb5ec fix package.json 2026-01-09 14:32:43 +05:30
Twisha Bansal
bb76fc6a84 fix testing 2026-01-09 14:32:43 +05:30
Twisha Bansal
21e4db2ed4 Update testing.md 2026-01-09 14:32:43 +05:30
Twisha Bansal
05e075b3fe Update testing.md 2026-01-09 14:32:43 +05:30
Twisha Bansal
e94eebc6af fix package 2026-01-09 14:32:43 +05:30
Twisha Bansal
e50dc777b7 Fixed readme 2026-01-09 14:32:43 +05:30
Twisha Bansal
1aaa84313c add markdown on how to test 2026-01-09 14:32:43 +05:30
Twisha Bansal
829e3168ff fix 2026-01-09 14:32:43 +05:30
Twisha Bansal
3d76f60401 fix 2026-01-09 14:32:43 +05:30
Twisha Bansal
d065b41598 add files 2026-01-09 14:32:43 +05:30
567 changed files with 5084 additions and 5084 deletions

View File

@@ -59,13 +59,6 @@ You can manually trigger the bot by commenting on your Pull Request:
* `/gemini summary`: Posts a summary of the changes in the pull request.
* `/gemini help`: Overview of the available commands
## Guidelines for Pull Requests
1. Please keep your PR small for more thorough review and easier updates. In case of regression, it also allows us to roll back a single feature instead of multiple ones.
1. For non-trivial changes, consider opening an issue and discussing it with the code owners first.
1. Provide a good PR description as a record of what change is being made and why it was made. Link to a GitHub issue if it exists.
1. Make sure your code is thoroughly tested with unit tests and integration tests. Remember to clean up the test instances properly in your code to avoid memory leaks.
## Adding a New Database Source or Tool
Please create an
@@ -92,11 +85,11 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
`newdb.go`. Create a `Config` struct to include all the necessary parameters
for connecting to the database (e.g., host, port, username, password, database
name) and a `Source` struct to store necessary parameters for tools (e.g.,
Name, Type, connection object, additional config).
Name, Kind, connection object, additional config).
* **Implement the
[`SourceConfig`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/sources/sources.go#L57)
interface**. This interface requires two methods:
* `SourceConfigType() string`: Returns a unique string identifier for your
* `SourceConfigKind() string`: Returns a unique string identifier for your
data source (e.g., `"newdb"`).
* `Initialize(ctx context.Context, tracer trace.Tracer) (Source, error)`:
Creates a new instance of your data source and establishes a connection to
@@ -104,7 +97,7 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
* **Implement the
[`Source`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/sources/sources.go#L63)
interface**. This interface requires one method:
* `SourceType() string`: Returns the same string identifier as `SourceConfigType()`.
* `SourceKind() string`: Returns the same string identifier as `SourceConfigKind()`.
* **Implement `init()`** to register the new Source.
* **Implement Unit Tests** in a file named `newdb_test.go`.
@@ -117,8 +110,6 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
We recommend looking at an [example tool
implementation](https://github.com/googleapis/genai-toolbox/tree/main/internal/tools/postgres/postgressql).
Remember to keep your PRs small. For example, if you are contributing a new Source, only include one or two core Tools within the same PR, the rest of the Tools can come in subsequent PRs.
* **Create a new directory** under `internal/tools` for your tool type (e.g., `internal/tools/newdb/newdbtool`).
* **Define a configuration struct** for your tool in a file named `newdbtool.go`.
Create a `Config` struct and a `Tool` struct to store necessary parameters for
@@ -126,7 +117,7 @@ tools.
* **Implement the
[`ToolConfig`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/tools/tools.go#L61)
interface**. This interface requires one method:
* `ToolConfigType() string`: Returns a unique string identifier for your tool
* `ToolConfigKind() string`: Returns a unique string identifier for your tool
(e.g., `"newdb-tool"`).
* `Initialize(sources map[string]Source) (Tool, error)`: Creates a new
instance of your tool and validates that it can connect to the specified
@@ -172,8 +163,6 @@ tools.
parameters][temp-param-doc]. Only run this test if template
parameters apply to your tool.
* **Add additional tests** for the tools that are not covered by the predefined tests. Every tool must be tested!
* **Add the new database to the integration test workflow** in
[integration.cloudbuild.yaml](.ci/integration.cloudbuild.yaml).
@@ -243,7 +232,7 @@ resources.
| style | Update src code, with only formatting and whitespace updates (e.g. code formatter or linter changes). |
Pull requests should always add scope whenever possible. The scope is
formatted as `<scope-resource>/<scope-type>` (e.g., `sources/postgres`, or
formatted as `<scope-type>/<scope-kind>` (e.g., `sources/postgres`, or
`tools/mssql-sql`).
Ideally, **each PR covers only one scope**, if this is
@@ -255,4 +244,4 @@ resources.
* **PR Description:** PR description should **always** be included. It should
include a concise description of the changes, it's impact, along with a
summary of the solution. If the PR is related to a specific issue, the issue
number should be mentioned in the PR description (e.g. `Fixes #1`).
number should be mentioned in the PR description (e.g. `Fixes #1`).

View File

@@ -379,23 +379,6 @@ to approve PRs for main. TeamSync is used to create this team from the MDB
Group `toolbox-contributors`. Googlers who are developing for MCP-Toolbox
but aren't part of the core team should join this group.
### Issue/PR Triage and SLO
After an issue is created, maintainers will assign the following labels:
* `Priority` (defaulted to P0)
* `Type` (if applicable)
* `Product` (if applicable)
All incoming issues and PRs will follow the following SLO:
| Type | Priority | Objective |
|-----------------|----------|------------------------------------------------------------------------|
| Feature Request | P0 | Must respond within **5 days** |
| Process | P0 | Must respond within **5 days** |
| Bugs | P0 | Must respond within **5 days**, and resolve/closure within **14 days** |
| Bugs | P1 | Must respond within **7 days**, and resolve/closure within **90 days** |
| Bugs | P2 | Must respond within **30 days**
_Types that are not listed in the table do not adhere to any SLO._
### Releasing
Toolbox has two types of releases: versioned and continuous. It uses Google

View File

@@ -272,7 +272,7 @@ To run Toolbox from binary:
To run the server after pulling the [container image](#installing-the-server):
```sh
export VERSION=0.24.0 # Use the version you pulled
export VERSION=0.11.0 # Use the version you pulled
docker run -p 5000:5000 \
-v $(pwd)/tools.yaml:/app/tools.yaml \
us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION \
@@ -954,7 +954,7 @@ For more details on configuring different types of sources, see the
### Tools
The `tools` section of a `tools.yaml` define the actions an agent can take: what
type of tool it is, which source(s) it affects, what parameters it uses, etc.
kind of tool it is, which source(s) it affects, what parameters it uses, etc.
```yaml
tools:

View File

@@ -15,7 +15,6 @@
package cmd
import (
"bytes"
"context"
_ "embed"
"fmt"
@@ -93,7 +92,6 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances"
@@ -426,103 +424,6 @@ func parseEnv(input string) (string, error) {
return output, err
}
func convertToolsFile(ctx context.Context, raw []byte) ([]byte, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, err
}
// TODO: add embeddingmodels when available
keysToCheck := []string{"sources", "authServices", "authSources", "tools", "prompts", "toolsets"}
var input yaml.MapSlice
decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap())
if err := decoder.Decode(&input); err != nil {
return nil, err
}
// Convert raw MapSlice to a helper map for quick lookup
// while keeping the values as MapSlices to preserve internal order
lookup := make(map[string]yaml.MapSlice)
for _, item := range input {
key := item.Key.(string)
if slice, ok := item.Value.(yaml.MapSlice); ok {
// convert authSources to authServices
if key == "authSources" {
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
key = "authServices"
}
// works even if lookup[key] is nil
lookup[key] = append(lookup[key], slice...)
} else {
// toolsfile is already v2
if key == "kind" {
return raw, nil
}
return nil, fmt.Errorf("'%s' is not a map", key)
}
}
// convert to tools file v2
var buf bytes.Buffer
encoder := yaml.NewEncoder(&buf)
for _, kind := range keysToCheck {
data, exists := lookup[kind]
if !exists {
// if this is skipped for all keys, the tools file is in v2
continue
}
// Transform each entry
for _, entry := range data {
entryName := entry.Key.(string)
entryBody := ProcessValue(entry.Value, kind == "toolsets")
transformed := yaml.MapSlice{
{Key: "kind", Value: kind},
{Key: "name", Value: entryName},
}
// Merge the transformed body into our result
if bodySlice, ok := entryBody.(yaml.MapSlice); ok {
transformed = append(transformed, bodySlice...)
} else {
return nil, fmt.Errorf("unable to convert entryBody to MapSlice")
}
if err := encoder.Encode(transformed); err != nil {
return nil, err
}
}
}
return buf.Bytes(), nil
}
// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type'
func ProcessValue(v any, isToolset bool) any {
switch val := v.(type) {
case yaml.MapSlice:
for i := range val {
// Perform renaming
if val[i].Key == "kind" {
val[i].Key = "type"
}
// Recursive call for nested values (e.g., nested objects or lists)
val[i].Value = ProcessValue(val[i].Value, false)
}
return val
case []any:
// Process lists: If it's a toolset top-level list, wrap it.
if isToolset {
return yaml.MapSlice{{Key: "tools", Value: val}}
}
// Otherwise, recurse into list items (to catch nested objects)
for i := range val {
val[i] = ProcessValue(val[i], false)
}
return val
default:
return val
}
}
// parseToolsFile parses the provided yaml into appropriate configs.
func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
var toolsFile ToolsFile

View File

@@ -23,14 +23,12 @@ import (
"os"
"path"
"path/filepath"
"reflect"
"regexp"
"runtime"
"strings"
"testing"
"time"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/auth/google"
@@ -496,235 +494,6 @@ func TestDefaultLogLevel(t *testing.T) {
}
}
func TestConvertToolsFile(t *testing.T) {
ctx, cancelCtx := context.WithTimeout(context.Background(), time.Minute)
defer cancelCtx()
pr, pw := io.Pipe()
defer pw.Close()
defer pr.Close()
logger, err := log.NewStdLogger(pw, pw, "DEBUG")
if err != nil {
t.Fatalf("failed to setup logger %s", err)
}
ctx = util.WithLogger(ctx, logger)
tcs := []struct {
desc string
in string
want string
isErr bool
errStr string
}{
{
desc: "basic convert",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
tools:
example_tool:
kind: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
toolsets:
example_toolset:
- example_tool`,
want: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
},
{
desc: "rearrange resource order",
in: `
tools:
example_tool:
kind: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
toolsets:
example_toolset:
- example_tool`,
want: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
},
{
desc: "no convertion needed",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
want: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
},
{
desc: "invalid source",
in: `sources: invalid`,
isErr: true,
errStr: "'sources' is not a map",
},
{
desc: "invalid toolset",
in: `toolsets: invalid`,
isErr: true,
errStr: "'toolsets' is not a map",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
output, err := convertToolsFile(ctx, []byte(tc.in))
if tc.isErr {
if err == nil {
t.Fatalf("missing error: %s", tc.errStr)
}
if err.Error() != tc.errStr {
t.Fatalf("invalid error string: got %s, want %s", err, tc.errStr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
// ensures that the order is correct
var doc1, doc2 yaml.MapSlice
if err := yaml.Unmarshal(output, &doc1); err != nil {
t.Fatalf("unable to unmarshal output: %s", string(output))
}
if err := yaml.Unmarshal([]byte(tc.want), &doc2); err != nil {
t.Fatalf("unable to unmarshal output: %s", tc.want)
}
if !reflect.DeepEqual(doc1, doc2) {
t.Fatalf("incorrect output: got %s, want %s", doc1, doc2)
}
})
}
}
func TestParseToolFile(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
@@ -766,7 +535,7 @@ func TestParseToolFile(t *testing.T) {
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Type: cloudsqlpgsrc.SourceType,
Kind: cloudsqlpgsrc.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -779,7 +548,7 @@ func TestParseToolFile(t *testing.T) {
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Type: "postgres-sql",
Kind: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
@@ -920,7 +689,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Type: cloudsqlpgsrc.SourceType,
Kind: cloudsqlpgsrc.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -933,19 +702,19 @@ func TestParseToolFileWithAuth(t *testing.T) {
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "my-client-id",
},
"other-google-service": google.Config{
Name: "other-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "other-client-id",
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Type: "postgres-sql",
Kind: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
@@ -1020,7 +789,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Type: cloudsqlpgsrc.SourceType,
Kind: cloudsqlpgsrc.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -1033,19 +802,19 @@ func TestParseToolFileWithAuth(t *testing.T) {
AuthSources: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "my-client-id",
},
"other-google-service": google.Config{
Name: "other-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "other-client-id",
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Type: "postgres-sql",
Kind: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
@@ -1122,7 +891,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Type: cloudsqlpgsrc.SourceType,
Kind: cloudsqlpgsrc.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -1135,19 +904,19 @@ func TestParseToolFileWithAuth(t *testing.T) {
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "my-client-id",
},
"other-google-service": google.Config{
Name: "other-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "other-client-id",
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Type: "postgres-sql",
Kind: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
@@ -1293,7 +1062,7 @@ func TestEnvVarReplacement(t *testing.T) {
Sources: server.SourceConfigs{
"my-http-instance": httpsrc.Config{
Name: "my-http-instance",
Type: httpsrc.SourceType,
Kind: httpsrc.SourceKind,
BaseURL: "http://test_server/",
Timeout: "10s",
DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"},
@@ -1303,19 +1072,19 @@ func TestEnvVarReplacement(t *testing.T) {
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "ACTUAL_CLIENT_ID",
},
"other-google-service": google.Config{
Name: "other-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "ACTUAL_CLIENT_ID_2",
},
},
Tools: server.ToolConfigs{
"example_tool": http.Config{
Name: "example_tool",
Type: "http",
Kind: "http",
Source: "my-instance",
Method: "GET",
Path: "search?name=alice&pet=cat",
@@ -1724,7 +1493,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_postgres_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_postgres_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup"},
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance"},
},
},
},
@@ -1734,7 +1503,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_mysql_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_mysql_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
},
},
},
@@ -1744,7 +1513,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_mssql_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_mssql_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
},
},
},

View File

View File

@@ -48,7 +48,6 @@ instance, database and users:
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* `roles/cloudsql.admin`: Provides full control over all resources.
* All `editor` and `viewer` tools
* `create_instance`
@@ -300,7 +299,6 @@ instances and interacting with your database:
* **create_user**: Creates a new user in a Cloud SQL instance.
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for SQL Server instance.
* **create_backup**: Creates a backup on a Cloud SQL instance.
{{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -48,7 +48,6 @@ database and users:
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* `roles/cloudsql.admin`: Provides full control over all resources.
* All `editor` and `viewer` tools
* `create_instance`
@@ -300,7 +299,6 @@ instances and interacting with your database:
* **create_user**: Creates a new user in a Cloud SQL instance.
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for MySQL instance.
* **create_backup**: Creates a backup on a Cloud SQL instance.
{{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -48,7 +48,6 @@ instance, database and users:
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* `roles/cloudsql.admin`: Provides full control over all resources.
* All `editor` and `viewer` tools
* `create_instance`
@@ -300,7 +299,6 @@ instances and interacting with your database:
* **create_user**: Creates a new user in a Cloud SQL instance.
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for PostgreSQL instance.
* **create_backup**: Creates a backup on a Cloud SQL instance.
{{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -187,7 +187,6 @@ See [Usage Examples](../reference/cli.md#examples).
manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
all resources.
* All `editor` and `viewer` tools
@@ -204,7 +203,6 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_user`: Creates a new user in a Cloud SQL instance.
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for MySQL instance.
* `create_backup`: Creates a backup on a Cloud SQL instance.
## Cloud SQL for PostgreSQL
@@ -277,7 +275,6 @@ See [Usage Examples](../reference/cli.md#examples).
manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
all resources.
* All `editor` and `viewer` tools
@@ -293,7 +290,6 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_user`: Creates a new user in a Cloud SQL instance.
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for PostgreSQL instance.
* `create_backup`: Creates a backup on a Cloud SQL instance.
## Cloud SQL for SQL Server
@@ -340,7 +336,6 @@ See [Usage Examples](../reference/cli.md#examples).
manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
all resources.
* All `editor` and `viewer` tools
@@ -356,7 +351,6 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_user`: Creates a new user in a Cloud SQL instance.
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for SQL Server instance.
* `create_backup`: Creates a backup on a Cloud SQL instance.
## Dataplex

View File

@@ -134,7 +134,6 @@ sources:
# scopes: # Optional: List of OAuth scopes to request.
# - "https://www.googleapis.com/auth/bigquery"
# - "https://www.googleapis.com/auth/drive.readonly"
# maxQueryResultRows: 50 # Optional: Limits the number of rows returned by queries. Defaults to 50.
```
Initialize a BigQuery source that uses the client's access token:
@@ -154,7 +153,6 @@ sources:
# scopes: # Optional: List of OAuth scopes to request.
# - "https://www.googleapis.com/auth/bigquery"
# - "https://www.googleapis.com/auth/drive.readonly"
# maxQueryResultRows: 50 # Optional: Limits the number of rows returned by queries. Defaults to 50.
```
## Reference
@@ -169,4 +167,3 @@ sources:
| useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. **Note:** This cannot be used with `writeMode: protected`. |
| scopes | []string | false | A list of OAuth 2.0 scopes to use for the credentials. If not provided, default scopes are used. |
| impersonateServiceAccount | string | false | Service account email to impersonate when making BigQuery and Dataplex API calls. The authenticated principal must have the `roles/iam.serviceAccountTokenCreator` role on the target service account. [Learn More](https://cloud.google.com/iam/docs/service-account-impersonation) |
| maxQueryResultRows | int | false | The maximum number of rows to return from a query. Defaults to 50. |

View File

@@ -91,8 +91,8 @@ visible to the LLM.
https://cloud.google.com/alloydb/docs/parameterized-secure-views-overview
{{< notice tip >}} Make sure to enable the `parameterized_views` extension
to utilize PSV feature (`nlConfigParameters`) with this tool. You can do so by
running this command in the AlloyDB studio:
before running this tool. You can do so by running this command in the AlloyDB
studio:
```sql
CREATE EXTENSION IF NOT EXISTS parameterized_views;

View File

@@ -1,45 +0,0 @@
---
title: cloud-sql-create-backup
type: docs
weight: 10
description: "Creates a backup on a Cloud SQL instance."
---
The `cloud-sql-create-backup` tool creates an on-demand backup on a Cloud SQL instance using the Cloud SQL Admin API.
{{< notice info dd>}}
This tool uses a `source` of kind `cloud-sql-admin`.
{{< /notice >}}
## Examples
Basic backup creation (current state)
```yaml
tools:
backup-creation-basic:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
description: "Creates a backup on the given Cloud SQL instance."
```
## Reference
### Tool Configuration
| **field** | **type** | **required** | **description** |
| -------------- | :------: | :----------: | ------------------------------------------------------------- |
| kind | string | true | Must be "cloud-sql-create-backup". |
| source | string | true | The name of the `cloud-sql-admin` source to use. |
| description | string | false | A description of the tool. |
### Tool Inputs
| **parameter** | **type** | **required** | **description** |
| -------------------------- | :------: | :----------: | ------------------------------------------------------------------------------- |
| project | string | true | The project ID. |
| instance | string | true | The name of the instance to take a backup on. Does not include the project ID. |
| location | string | false | (Optional) Location of the backup run. |
| backup_description | string | false | (Optional) The description of this backup run. |
## See Also
- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api)
- [Toolbox Cloud SQL tools documentation](../cloudsql)
- [Cloud SQL Backup API documentation](https://cloud.google.com/sql/docs/mysql/backup-recovery/backups)

View File

@@ -1,7 +0,0 @@
---
title: "Neo4j"
type: docs
weight: 1
description: >
How to get started with Toolbox using Neo4j.
---

View File

@@ -1,141 +0,0 @@
---
title: "Quickstart (MCP with Neo4j)"
type: docs
weight: 1
description: >
How to get started running Toolbox with MCP Inspector and Neo4j as the source.
---
## Overview
[Model Context Protocol](https://modelcontextprotocol.io) is an open protocol that standardizes how applications provide context to LLMs. Check out this page on how to [connect to Toolbox via MCP](../../how-to/connect_via_mcp.md).
## Step 1: Set up your Neo4j Database and Data
In this section, you'll set up a database and populate it with sample data for a movies-related agent. This guide assumes you have a running Neo4j instance, either locally or in the cloud.
. **Populate the database with data.**
To make this quickstart straightforward, we'll use the built-in Movies dataset available in Neo4j.
. In your Neo4j Browser, run the following command to create and populate the database:
+
```cypher
:play movies
````
. Follow the instructions to load the data. This will create a graph with `Movie`, `Person`, and `Actor` nodes and their relationships.
## Step 2: Install and configure Toolbox
In this section, we will install the MCP Toolbox, configure our tools in a `tools.yaml` file, and then run the Toolbox server.
. **Install the Toolbox binary.**
The simplest way to get started is to download the latest binary for your operating system.
. Download the latest version of Toolbox as a binary:
\+
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O [https://storage.googleapis.com/genai-toolbox/v0.16.0/$OS/toolbox](https://storage.googleapis.com/genai-toolbox/v0.16.0/$OS/toolbox)
```
+
. Make the binary executable:
\+
```bash
chmod +x toolbox
```
. **Create the `tools.yaml` file.**
This file defines your Neo4j source and the specific tools that will be exposed to your AI agent.
\+
{{\< notice tip \>}}
Authentication for the Neo4j source uses standard username and password fields. For production use, it is highly recommended to use environment variables for sensitive information like passwords.
{{\< /notice \>}}
\+
Write the following into a `tools.yaml` file:
\+
```yaml
sources:
my-neo4j-source:
kind: neo4j
uri: bolt://localhost:7687
user: neo4j
password: my-password # Replace with your actual password
tools:
search-movies-by-actor:
kind: neo4j-cypher
source: my-neo4j-source
description: "Searches for movies an actor has appeared in based on their name. Useful for questions like 'What movies has Tom Hanks been in?'"
parameters:
- name: actor_name
type: string
description: The full name of the actor to search for.
statement: |
MATCH (p:Person {name: $actor_name}) -[:ACTED_IN]-> (m:Movie)
RETURN m.title AS title, m.year AS year, m.genre AS genre
get-actor-for-movie:
kind: neo4j-cypher
source: my-neo4j-source
description: "Finds the actors who starred in a specific movie. Useful for questions like 'Who acted in Inception?'"
parameters:
- name: movie_title
type: string
description: The exact title of the movie.
statement: |
MATCH (p:Person) -[:ACTED_IN]-> (m:Movie {title: $movie_title})
RETURN p.name AS actor
```
. **Start the Toolbox server.**
Run the Toolbox server, pointing to the `tools.yaml` file you created earlier.
\+
```bash
./toolbox --tools-file "tools.yaml"
```
## Step 3: Connect to MCP Inspector
. **Run the MCP Inspector:**
\+
```bash
npx @modelcontextprotocol/inspector
```
. Type `y` when it asks to install the inspector package.
. It should show the following when the MCP Inspector is up and running (please take note of `<YOUR_SESSION_TOKEN>`):
\+
```bash
Starting MCP inspector...
⚙️ Proxy server listening on localhost:6277
🔑 Session token: <YOUR_SESSION_TOKEN>
Use this token to authenticate requests or set DANGEROUSLY_OMIT_AUTH=true to disable auth
🚀 MCP Inspector is up and running at:
http://localhost:6274/?MCP_PROXY_AUTH_TOKEN=<YOUR_SESSION_TOKEN>
```
1. Open the above link in your browser.
1. For `Transport Type`, select `Streamable HTTP`.
1. For `URL`, type in `http://127.0.0.1:5000/mcp`.
1. For `Configuration` -\> `Proxy Session Token`, make sure `<YOUR_SESSION_TOKEN>` is present.
1. Click `Connect`.
1. Select `List Tools`, you will see a list of tools configured in `tools.yaml`.
1. Test out your tools here\!

View File

@@ -21,13 +21,13 @@ import (
// AuthServiceConfig is the interface for configuring authentication services.
type AuthServiceConfig interface {
AuthServiceConfigType() string
AuthServiceConfigKind() string
Initialize() (AuthService, error)
}
// AuthService is the interface for authentication services.
type AuthService interface {
AuthServiceType() string
AuthServiceKind() string
GetName() string
GetClaimsFromHeader(context.Context, http.Header) (map[string]any, error)
ToConfig() AuthServiceConfig

View File

@@ -23,7 +23,7 @@ import (
"google.golang.org/api/idtoken"
)
const AuthServiceType string = "google"
const AuthServiceKind string = "google"
// validate interface
var _ auth.AuthServiceConfig = Config{}
@@ -31,13 +31,13 @@ var _ auth.AuthServiceConfig = Config{}
// Auth service configuration
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ClientID string `yaml:"clientId" validate:"required"`
}
// Returns the auth service type
func (cfg Config) AuthServiceConfigType() string {
return AuthServiceType
// Returns the auth service kind
func (cfg Config) AuthServiceConfigKind() string {
return AuthServiceKind
}
// Initialize a Google auth service
@@ -55,9 +55,9 @@ type AuthService struct {
Config
}
// Returns the auth service type
func (a AuthService) AuthServiceType() string {
return AuthServiceType
// Returns the auth service kind
func (a AuthService) AuthServiceKind() string {
return AuthServiceKind
}
func (a AuthService) ToConfig() auth.AuthServiceConfig {

View File

@@ -22,12 +22,12 @@ import (
// EmbeddingModelConfig is the interface for configuring embedding models.
type EmbeddingModelConfig interface {
EmbeddingModelConfigType() string
EmbeddingModelConfigKind() string
Initialize(context.Context) (EmbeddingModel, error)
}
type EmbeddingModel interface {
EmbeddingModelType() string
EmbeddingModelKind() string
ToConfig() EmbeddingModelConfig
EmbedParameters(context.Context, []string) ([][]float32, error)
}

View File

@@ -23,22 +23,22 @@ import (
"google.golang.org/genai"
)
const EmbeddingModelType string = "gemini"
const EmbeddingModelKind string = "gemini"
// validate interface
var _ embeddingmodels.EmbeddingModelConfig = Config{}
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Model string `yaml:"model" validate:"required"`
ApiKey string `yaml:"apiKey"`
Dimension int32 `yaml:"dimension"`
}
// Returns the embedding model type
func (cfg Config) EmbeddingModelConfigType() string {
return EmbeddingModelType
// Returns the embedding model kind
func (cfg Config) EmbeddingModelConfigKind() string {
return EmbeddingModelKind
}
// Initialize a Gemini embedding model
@@ -69,9 +69,9 @@ type EmbeddingModel struct {
Config
}
// Returns the embedding model type
func (m EmbeddingModel) EmbeddingModelType() string {
return EmbeddingModelType
// Returns the embedding model kind
func (m EmbeddingModel) EmbeddingModelKind() string {
return EmbeddingModelKind
}
func (m EmbeddingModel) ToConfig() embeddingmodels.EmbeddingModelConfig {

View File

@@ -42,7 +42,7 @@ func TestParseFromYamlGemini(t *testing.T) {
want: map[string]embeddingmodels.EmbeddingModelConfig{
"my-gemini-model": gemini.Config{
Name: "my-gemini-model",
Type: gemini.EmbeddingModelType,
Kind: gemini.EmbeddingModelKind,
Model: "text-embedding-004",
},
},
@@ -60,7 +60,7 @@ func TestParseFromYamlGemini(t *testing.T) {
want: map[string]embeddingmodels.EmbeddingModelConfig{
"complex-gemini": gemini.Config{
Name: "complex-gemini",
Type: gemini.EmbeddingModelType,
Kind: gemini.EmbeddingModelKind,
Model: "text-embedding-004",
ApiKey: "test-api-key",
Dimension: 768,

View File

@@ -19,7 +19,6 @@ sources:
location: ${BIGQUERY_LOCATION:}
useClientOAuth: ${BIGQUERY_USE_CLIENT_OAUTH:false}
scopes: ${BIGQUERY_SCOPES:}
maxQueryResultRows: ${BIGQUERY_MAX_QUERY_RESULT_ROWS:50}
tools:
analyze_contribution:

View File

@@ -43,9 +43,6 @@ tools:
clone_instance:
kind: cloud-sql-clone-instance
source: cloud-sql-admin-source
create_backup:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
toolsets:
cloud_sql_mssql_admin_tools:
@@ -57,4 +54,3 @@ toolsets:
- create_user
- wait_for_operation
- clone_instance
- create_backup

View File

@@ -43,9 +43,6 @@ tools:
clone_instance:
kind: cloud-sql-clone-instance
source: cloud-sql-admin-source
create_backup:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
toolsets:
cloud_sql_mysql_admin_tools:
@@ -57,4 +54,3 @@ toolsets:
- create_user
- wait_for_operation
- clone_instance
- create_backup

View File

@@ -46,9 +46,6 @@ tools:
postgres_upgrade_precheck:
kind: postgres-upgrade-precheck
source: cloud-sql-admin-source
create_backup:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
toolsets:
cloud_sql_postgres_admin_tools:
@@ -61,4 +58,3 @@ toolsets:
- wait_for_operation
- postgres_upgrade_precheck
- clone_instance
- create_backup

View File

@@ -27,10 +27,10 @@ type Message = prompts.Message
const kind = "custom"
// init registers this prompt type with the prompt framework.
// init registers this prompt kind with the prompt framework.
func init() {
if !prompts.Register(kind, newConfig) {
panic(fmt.Sprintf("prompt type %q already registered", kind))
panic(fmt.Sprintf("prompt kind %q already registered", kind))
}
}
@@ -56,7 +56,7 @@ type Config struct {
var _ prompts.PromptConfig = Config{}
var _ prompts.Prompt = Prompt{}
func (c Config) PromptConfigType() string {
func (c Config) PromptConfigKind() string {
return kind
}

View File

@@ -50,8 +50,8 @@ func TestConfig(t *testing.T) {
if p == nil {
t.Fatal("Initialize() returned a nil prompt")
}
if cfg.PromptConfigType() != "custom" {
t.Errorf("PromptConfigType() = %q, want %q", cfg.PromptConfigType(), "custom")
if cfg.PromptConfigKind() != "custom" {
t.Errorf("PromptConfigKind() = %q, want %q", cfg.PromptConfigKind(), "custom")
}
t.Run("Manifest", func(t *testing.T) {

View File

@@ -52,7 +52,7 @@ func DecodeConfig(ctx context.Context, kind, name string, decoder *yaml.Decoder)
}
if !found {
return nil, fmt.Errorf("unknown prompt type: %q", kind)
return nil, fmt.Errorf("unknown prompt kind: %q", kind)
}
promptConfig, err := factory(ctx, name, decoder)
@@ -63,7 +63,7 @@ func DecodeConfig(ctx context.Context, kind, name string, decoder *yaml.Decoder)
}
type PromptConfig interface {
PromptConfigType() string
PromptConfigKind() string
Initialize() (Prompt, error)
}

View File

@@ -29,16 +29,16 @@ import (
type mockPromptConfig struct {
name string
Type string
kind string
}
func (m *mockPromptConfig) PromptConfigType() string { return m.Type }
func (m *mockPromptConfig) PromptConfigKind() string { return m.kind }
func (m *mockPromptConfig) Initialize() (prompts.Prompt, error) { return nil, nil }
var errMockFactory = errors.New("mock factory error")
func mockFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
return &mockPromptConfig{name: name, Type: "mockType"}, nil
return &mockPromptConfig{name: name, kind: "mockKind"}, nil
}
func mockErrorFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
@@ -50,7 +50,7 @@ func TestRegistry(t *testing.T) {
ctx := context.Background()
t.Run("RegisterAndDecodeSuccess", func(t *testing.T) {
kind := "testTypeSuccess"
kind := "testKindSuccess"
if !prompts.Register(kind, mockFactory) {
t.Fatal("expected registration to succeed")
}
@@ -69,19 +69,19 @@ func TestRegistry(t *testing.T) {
}
})
t.Run("DecodeUnknownType", func(t *testing.T) {
t.Run("DecodeUnknownKind", func(t *testing.T) {
decoder := yaml.NewDecoder(strings.NewReader(""))
_, err := prompts.DecodeConfig(ctx, "unregisteredType", "testPrompt", decoder)
_, err := prompts.DecodeConfig(ctx, "unregisteredKind", "testPrompt", decoder)
if err == nil {
t.Fatal("expected an error for unknown kind, but got nil")
}
if !strings.Contains(err.Error(), "unknown prompt type") {
t.Errorf("expected error to contain 'unknown prompt type', but got: %v", err)
if !strings.Contains(err.Error(), "unknown prompt kind") {
t.Errorf("expected error to contain 'unknown prompt kind', but got: %v", err)
}
})
t.Run("FactoryReturnsError", func(t *testing.T) {
kind := "testTypeError"
kind := "testKindError"
if !prompts.Register(kind, mockErrorFactory) {
t.Fatal("expected registration to succeed")
}
@@ -105,8 +105,8 @@ func TestRegistry(t *testing.T) {
if config == nil {
t.Fatal("expected a non-nil config for default kind")
}
if config.PromptConfigType() != "custom" {
t.Errorf("expected default kind to be 'custom', but got %q", config.PromptConfigType())
if config.PromptConfigKind() != "custom" {
t.Errorf("expected default kind to be 'custom', but got %q", config.PromptConfigKind())
}
})
}

View File

@@ -198,7 +198,7 @@ func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(i
return fmt.Errorf("error creating decoder: %w", err)
}
switch kind {
case google.AuthServiceType:
case google.AuthServiceKind:
actual := google.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)
@@ -242,7 +242,7 @@ func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal fun
return fmt.Errorf("error creating decoder: %w", err)
}
switch kind {
case gemini.EmbeddingModelType:
case gemini.EmbeddingModelKind:
actual := gemini.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)

View File

@@ -32,7 +32,7 @@ func TestUpdateServer(t *testing.T) {
"example-source": &alloydbpg.Source{
Config: alloydbpg.Config{
Name: "example-alloydb-source",
Type: "alloydb-postgres",
Kind: "alloydb-postgres",
},
},
}
@@ -92,7 +92,7 @@ func TestUpdateServer(t *testing.T) {
"example-source2": &alloydbpg.Source{
Config: alloydbpg.Config{
Name: "example-alloydb-source2",
Type: "alloydb-postgres",
Kind: "alloydb-postgres",
},
},
}

View File

@@ -82,7 +82,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
childCtx, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/source/init",
trace.WithAttributes(attribute.String("source_type", sc.SourceConfigType())),
trace.WithAttributes(attribute.String("source_kind", sc.SourceConfigKind())),
trace.WithAttributes(attribute.String("source_name", name)),
)
defer span.End()
@@ -110,7 +110,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/auth/init",
trace.WithAttributes(attribute.String("auth_type", sc.AuthServiceConfigType())),
trace.WithAttributes(attribute.String("auth_kind", sc.AuthServiceConfigKind())),
trace.WithAttributes(attribute.String("auth_name", name)),
)
defer span.End()
@@ -138,7 +138,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/embeddingmodel/init",
trace.WithAttributes(attribute.String("model_type", ec.EmbeddingModelConfigType())),
trace.WithAttributes(attribute.String("model_kind", ec.EmbeddingModelConfigKind())),
trace.WithAttributes(attribute.String("model_name", name)),
)
defer span.End()
@@ -166,7 +166,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/tool/init",
trace.WithAttributes(attribute.String("tool_type", tc.ToolConfigType())),
trace.WithAttributes(attribute.String("tool_kind", tc.ToolConfigKind())),
trace.WithAttributes(attribute.String("tool_name", name)),
)
defer span.End()
@@ -235,7 +235,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/prompt/init",
trace.WithAttributes(attribute.String("prompt_type", pc.PromptConfigType())),
trace.WithAttributes(attribute.String("prompt_kind", pc.PromptConfigKind())),
trace.WithAttributes(attribute.String("prompt_name", name)),
)
defer span.End()

View File

@@ -141,7 +141,7 @@ func TestUpdateServer(t *testing.T) {
"example-source": &alloydbpg.Source{
Config: alloydbpg.Config{
Name: "example-alloydb-source",
Type: "alloydb-postgres",
Kind: "alloydb-postgres",
},
},
}

View File

@@ -32,14 +32,14 @@ import (
"google.golang.org/api/option"
)
const SourceType string = "alloydb-admin"
const SourceKind string = "alloydb-admin"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -53,13 +53,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
DefaultProject string `yaml:"defaultProject"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -106,8 +106,8 @@ type Source struct {
Service *alloydbrestapi.Service
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -41,7 +41,7 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-alloydb-admin-instance": alloydbadmin.Config{
Name: "my-alloydb-admin-instance",
Type: alloydbadmin.SourceType,
Kind: alloydbadmin.SourceKind,
UseClientOAuth: false,
},
},
@@ -57,7 +57,7 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-alloydb-admin-instance": alloydbadmin.Config{
Name: "my-alloydb-admin-instance",
Type: alloydbadmin.SourceType,
Kind: alloydbadmin.SourceKind,
UseClientOAuth: true,
},
},

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "alloydb-postgres"
const SourceKind string = "alloydb-postgres"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Cluster string `yaml:"cluster" validate:"required"`
@@ -61,8 +61,8 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -90,8 +90,8 @@ type Source struct {
Pool *pgxpool.Pool
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -183,7 +183,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string
func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, cluster, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
dsn, useIAM, err := getConnectionConfig(ctx, user, pass, dbname)

View File

@@ -48,7 +48,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-pg-instance": alloydbpg.Config{
Name: "my-pg-instance",
Type: alloydbpg.SourceType,
Kind: alloydbpg.SourceKind,
Project: "my-project",
Region: "my-region",
Cluster: "my-cluster",
@@ -78,7 +78,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-pg-instance": alloydbpg.Config{
Name: "my-pg-instance",
Type: alloydbpg.SourceType,
Kind: alloydbpg.SourceKind,
Project: "my-project",
Region: "my-region",
Cluster: "my-cluster",
@@ -108,7 +108,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-pg-instance": alloydbpg.Config{
Name: "my-pg-instance",
Type: alloydbpg.SourceType,
Kind: alloydbpg.SourceKind,
Project: "my-project",
Region: "my-region",
Cluster: "my-cluster",

View File

@@ -41,7 +41,7 @@ import (
"google.golang.org/api/option"
)
const SourceType string = "bigquery"
const SourceKind string = "bigquery"
// CloudPlatformScope is a broad scope for Google Cloud Platform services.
const CloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
@@ -65,8 +65,8 @@ type BigQuerySessionProvider func(ctx context.Context) (*Session, error)
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -81,7 +81,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// BigQuery configs
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Location string `yaml:"location"`
WriteMode string `yaml:"writeMode"`
@@ -89,7 +89,6 @@ type Config struct {
UseClientOAuth bool `yaml:"useClientOAuth"`
ImpersonateServiceAccount string `yaml:"impersonateServiceAccount"`
Scopes StringOrStringSlice `yaml:"scopes"`
MaxQueryResultRows int `yaml:"maxQueryResultRows"`
}
// StringOrStringSlice is a custom type that can unmarshal both a single string
@@ -119,19 +118,15 @@ func (s *StringOrStringSlice) UnmarshalYAML(unmarshal func(any) error) error {
return fmt.Errorf("cannot unmarshal %T into StringOrStringSlice", v)
}
func (r Config) SourceConfigType() string {
func (r Config) SourceConfigKind() string {
// Returns BigQuery source kind
return SourceType
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
if r.WriteMode == "" {
r.WriteMode = WriteModeAllowed
}
if r.MaxQueryResultRows == 0 {
r.MaxQueryResultRows = 50
}
if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
// The protected mode only allows write operations to the session's temporary datasets.
// when using client OAuth, a new session is created every
@@ -155,7 +150,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
Client: client,
RestService: restService,
TokenSource: tokenSource,
MaxQueryResultRows: r.MaxQueryResultRows,
MaxQueryResultRows: 50,
ClientCreator: clientCreator,
}
@@ -302,9 +297,9 @@ type Session struct {
LastUsed time.Time
}
func (s *Source) SourceType() string {
func (s *Source) SourceKind() string {
// Returns BigQuery Google SQL source kind
return SourceType
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -572,7 +567,7 @@ func (s *Source) RunSQL(ctx context.Context, bqClient *bigqueryapi.Client, state
}
var out []any
for s.MaxQueryResultRows <= 0 || len(out) < s.MaxQueryResultRows {
for {
var val []bigqueryapi.Value
err = it.Next(&val)
if err == iterator.Done {
@@ -665,7 +660,7 @@ func initBigQueryConnection(
impersonateServiceAccount string,
scopes []string,
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)
@@ -741,7 +736,7 @@ func initBigQueryConnectionWithOAuthToken(
tokenString string,
wantRestService bool,
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Construct token source
token := &oauth2.Token{
@@ -801,7 +796,7 @@ func initDataplexConnection(
var clientCreator DataplexClientCreator
var err error
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -21,12 +21,9 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"go.opentelemetry.io/otel/trace/noop"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/util"
)
func TestParseFromYamlBigQuery(t *testing.T) {
@@ -46,7 +43,7 @@ func TestParseFromYamlBigQuery(t *testing.T) {
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "",
WriteMode: "",
@@ -66,7 +63,7 @@ func TestParseFromYamlBigQuery(t *testing.T) {
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "asia",
WriteMode: "blocked",
@@ -87,7 +84,7 @@ func TestParseFromYamlBigQuery(t *testing.T) {
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
UseClientOAuth: true,
@@ -108,7 +105,7 @@ func TestParseFromYamlBigQuery(t *testing.T) {
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
AllowedDatasets: []string{"my_dataset"},
@@ -128,7 +125,7 @@ func TestParseFromYamlBigQuery(t *testing.T) {
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
ImpersonateServiceAccount: "service-account@my-project.iam.gserviceaccount.com",
@@ -150,33 +147,13 @@ func TestParseFromYamlBigQuery(t *testing.T) {
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
Scopes: []string{"https://www.googleapis.com/auth/bigquery", "https://www.googleapis.com/auth/cloud-platform"},
},
},
},
{
desc: "with max query result rows example",
in: `
sources:
my-instance:
kind: bigquery
project: my-project
location: us
maxQueryResultRows: 10
`,
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Project: "my-project",
Location: "us",
MaxQueryResultRows: 10,
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
@@ -243,59 +220,6 @@ func TestFailParseFromYaml(t *testing.T) {
}
}
func TestInitialize_MaxQueryResultRows(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
ctx = util.WithUserAgent(ctx, "test-agent")
tracer := noop.NewTracerProvider().Tracer("")
tcs := []struct {
desc string
cfg bigquery.Config
want int
}{
{
desc: "default value",
cfg: bigquery.Config{
Name: "test-default",
Type: bigquery.SourceType,
Project: "test-project",
UseClientOAuth: true,
},
want: 50,
},
{
desc: "configured value",
cfg: bigquery.Config{
Name: "test-configured",
Type: bigquery.SourceType,
Project: "test-project",
UseClientOAuth: true,
MaxQueryResultRows: 100,
},
want: 100,
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
src, err := tc.cfg.Initialize(ctx, tracer)
if err != nil {
t.Fatalf("Initialize failed: %v", err)
}
bqSrc, ok := src.(*bigquery.Source)
if !ok {
t.Fatalf("Expected *bigquery.Source, got %T", src)
}
if bqSrc.MaxQueryResultRows != tc.want {
t.Errorf("MaxQueryResultRows = %d, want %d", bqSrc.MaxQueryResultRows, tc.want)
}
})
}
}
func TestNormalizeValue(t *testing.T) {
tests := []struct {
name string

View File

@@ -27,14 +27,14 @@ import (
"google.golang.org/api/option"
)
const SourceType string = "bigtable"
const SourceKind string = "bigtable"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -48,13 +48,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -77,8 +77,8 @@ type Source struct {
Client *bigtable.Client
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -179,7 +179,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, configParam param
func initBigtableClient(ctx context.Context, tracer trace.Tracer, name, project, instance string) (*bigtable.Client, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Set up Bigtable data operations client.

View File

@@ -43,7 +43,7 @@ func TestParseFromYamlBigtableDb(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-bigtable-instance": bigtable.Config{
Name: "my-bigtable-instance",
Type: bigtable.SourceType,
Kind: bigtable.SourceKind,
Project: "my-project",
Instance: "my-instance",
},

View File

@@ -25,11 +25,11 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "cassandra"
const SourceKind string = "cassandra"
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -43,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Hosts []string `yaml:"hosts" validate:"required"`
Keyspace string `yaml:"keyspace"`
ProtoVersion int `yaml:"protoVersion"`
@@ -68,9 +68,9 @@ func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return s, nil
}
// SourceConfigType implements sources.SourceConfig.
func (c Config) SourceConfigType() string {
return SourceType
// SourceConfigKind implements sources.SourceConfig.
func (c Config) SourceConfigKind() string {
return SourceKind
}
var _ sources.SourceConfig = Config{}
@@ -89,9 +89,9 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
// SourceType implements sources.Source.
func (s *Source) SourceType() string {
return SourceType
// SourceKind implements sources.Source.
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) RunSQL(ctx context.Context, statement string, params parameters.ParamValues) (any, error) {
@@ -120,7 +120,7 @@ var _ sources.Source = &Source{}
func initCassandraSession(ctx context.Context, tracer trace.Tracer, c Config) (*gocql.Session, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, c.Name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, c.Name)
defer span.End()
// Validate authentication configuration

View File

@@ -43,7 +43,7 @@ func TestParseFromYamlCassandra(t *testing.T) {
want: server.SourceConfigs{
"my-cassandra-instance": cassandra.Config{
Name: "my-cassandra-instance",
Type: cassandra.SourceType,
Kind: cassandra.SourceKind,
Hosts: []string{"my-host1", "my-host2"},
Username: "",
Password: "",
@@ -77,7 +77,7 @@ func TestParseFromYamlCassandra(t *testing.T) {
want: server.SourceConfigs{
"my-cassandra-instance": cassandra.Config{
Name: "my-cassandra-instance",
Type: cassandra.SourceType,
Kind: cassandra.SourceKind,
Hosts: []string{"my-host1", "my-host2"},
Username: "user",
Password: "pass",

View File

@@ -28,14 +28,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "clickhouse"
const SourceKind string = "clickhouse"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
Database string `yaml:"database" validate:"required"`
@@ -59,8 +59,8 @@ type Config struct {
Secure bool `yaml:"secure"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -88,8 +88,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -174,7 +174,7 @@ func validateConfig(protocol string) error {
func initClickHouseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, protocol string, secure bool) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
if protocol == "" {

View File

@@ -25,10 +25,10 @@ import (
"go.opentelemetry.io/otel"
)
func TestConfigSourceConfigType(t *testing.T) {
func TestConfigSourceConfigKind(t *testing.T) {
config := Config{}
if config.SourceConfigType() != SourceType {
t.Errorf("Expected %s, got %s", SourceType, config.SourceConfigType())
if config.SourceConfigKind() != SourceKind {
t.Errorf("Expected %s, got %s", SourceKind, config.SourceConfigKind())
}
}
@@ -53,7 +53,7 @@ func TestNewConfig(t *testing.T) {
`,
expected: Config{
Name: "test-clickhouse",
Type: "clickhouse",
Kind: "clickhouse",
Host: "localhost",
Port: "8443",
User: "default",
@@ -75,7 +75,7 @@ func TestNewConfig(t *testing.T) {
`,
expected: Config{
Name: "minimal-clickhouse",
Type: "clickhouse",
Kind: "clickhouse",
Host: "127.0.0.1",
Port: "8123",
User: "testuser",
@@ -100,7 +100,7 @@ func TestNewConfig(t *testing.T) {
`,
expected: Config{
Name: "http-clickhouse",
Type: "clickhouse",
Kind: "clickhouse",
Host: "clickhouse.example.com",
Port: "8123",
User: "analytics",
@@ -125,7 +125,7 @@ func TestNewConfig(t *testing.T) {
`,
expected: Config{
Name: "secure-clickhouse",
Type: "clickhouse",
Kind: "clickhouse",
Host: "secure.clickhouse.io",
Port: "8443",
User: "secureuser",
@@ -196,10 +196,10 @@ func TestNewConfigInvalidYAML(t *testing.T) {
}
}
func TestSource_SourceType(t *testing.T) {
func TestSource_SourceKind(t *testing.T) {
source := &Source{}
if source.SourceType() != SourceType {
t.Errorf("Expected %s, got %s", SourceType, source.SourceType())
if source.SourceKind() != SourceKind {
t.Errorf("Expected %s, got %s", SourceKind, source.SourceKind())
}
}

View File

@@ -29,15 +29,15 @@ import (
"golang.org/x/oauth2/google"
)
const SourceType string = "cloud-gemini-data-analytics"
const SourceKind string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -51,13 +51,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ProjectID string `yaml:"projectId" validate:"required"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize initializes a Gemini Data Analytics Source instance.
@@ -102,8 +102,8 @@ type Source struct {
userAgent string
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -47,7 +47,7 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-gda-instance": cloudgda.Config{
Name: "my-gda-instance",
Type: cloudgda.SourceType,
Kind: cloudgda.SourceKind,
ProjectID: "test-project-id",
UseClientOAuth: false,
},
@@ -65,7 +65,7 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-gda-instance": cloudgda.Config{
Name: "my-gda-instance",
Type: cloudgda.SourceType,
Kind: cloudgda.SourceKind,
ProjectID: "another-project",
UseClientOAuth: true,
},
@@ -153,12 +153,12 @@ func TestInitialize(t *testing.T) {
}{
{
desc: "initialize with ADC",
cfg: cloudgda.Config{Name: "test-gda", Type: cloudgda.SourceType, ProjectID: "test-proj"},
cfg: cloudgda.Config{Name: "test-gda", Kind: cloudgda.SourceKind, ProjectID: "test-proj"},
wantClientOAuth: false,
},
{
desc: "initialize with client OAuth",
cfg: cloudgda.Config{Name: "test-gda-oauth", Type: cloudgda.SourceType, ProjectID: "test-proj", UseClientOAuth: true},
cfg: cloudgda.Config{Name: "test-gda-oauth", Kind: cloudgda.SourceKind, ProjectID: "test-proj", UseClientOAuth: true},
wantClientOAuth: true,
},
}

View File

@@ -16,12 +16,8 @@ package cloudhealthcare
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -34,7 +30,7 @@ import (
"google.golang.org/api/option"
)
const SourceType string = "cloud-healthcare"
const SourceKind string = "cloud-healthcare"
// validate interface
var _ sources.SourceConfig = Config{}
@@ -42,8 +38,8 @@ var _ sources.SourceConfig = Config{}
type HealthcareServiceCreator func(tokenString string) (*healthcare.Service, error)
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -58,7 +54,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Healthcare configs
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Dataset string `yaml:"dataset" validate:"required"`
@@ -67,8 +63,8 @@ type Config struct {
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (c Config) SourceConfigType() string {
return SourceType
func (c Config) SourceConfigKind() string {
return SourceKind
}
func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -144,7 +140,7 @@ func newHealthcareServiceCreator(ctx context.Context, tracer trace.Tracer, name
}
func initHealthcareConnectionWithOAuthToken(ctx context.Context, tracer trace.Tracer, name string, userAgent string, tokenString string) (*healthcare.Service, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Construct token source
token := &oauth2.Token{
@@ -162,7 +158,7 @@ func initHealthcareConnectionWithOAuthToken(ctx context.Context, tracer trace.Tr
}
func initHealthcareConnection(ctx context.Context, tracer trace.Tracer, name string) (*healthcare.Service, oauth2.TokenSource, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
cred, err := google.FindDefaultCredentials(ctx, healthcare.CloudHealthcareScope)
@@ -194,8 +190,8 @@ type Source struct {
allowedDICOMStores map[string]struct{}
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -259,299 +255,3 @@ func (s *Source) IsDICOMStoreAllowed(storeID string) bool {
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}
func parseResults(resp *http.Response) (any, error) {
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
var jsonMap map[string]interface{}
if err := json.Unmarshal(respBytes, &jsonMap); err != nil {
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
}
return jsonMap, nil
}
func (s *Source) getService(tokenStr string) (*healthcare.Service, error) {
svc := s.Service()
var err error
// Initialize new service if using user OAuth token
if s.UseClientAuthorization() {
svc, err = s.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
return svc, nil
}
func (s *Source) FHIRFetchPage(ctx context.Context, url, tokenStr string) (any, error) {
var httpClient *http.Client
if s.UseClientAuthorization() {
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
httpClient = oauth2.NewClient(ctx, ts)
} else {
// The source.Service() object holds a client with the default credentials.
// However, the client is not exported, so we have to create a new one.
var err error
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
if err != nil {
return nil, fmt.Errorf("failed to create default http client: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create http request: %w", err)
}
req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
}
defer resp.Body.Close()
return parseResults(resp)
}
func (s *Source) FHIRPatientEverything(storeID, patientID, tokenStr string, opts []googleapi.CallOption) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", s.Project(), s.Region(), s.DatasetID(), storeID, patientID)
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.PatientEverything(name).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to call patient everything for %q: %w", name, err)
}
defer resp.Body.Close()
return parseResults(resp)
}
func (s *Source) FHIRPatientSearch(storeID, tokenStr string, opts []googleapi.CallOption) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search patient resources: %w", err)
}
defer resp.Body.Close()
return parseResults(resp)
}
func (s *Source) GetDataset(tokenStr string) (*healthcare.Dataset, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
return dataset, nil
}
func (s *Source) GetFHIRResource(storeID, resType, resID, tokenStr string) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", s.Project(), s.Region(), s.DatasetID(), storeID, resType, resID)
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
resp, err := call.Do()
if err != nil {
return nil, fmt.Errorf("failed to get fhir resource %q: %w", name, err)
}
defer resp.Body.Close()
return parseResults(resp)
}
func (s *Source) GetDICOMStore(storeID, tokenStr string) (*healthcare.DicomStore, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
}
return store, nil
}
func (s *Source) GetFHIRStore(storeID, tokenStr string) (*healthcare.FhirStore, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
}
return store, nil
}
func (s *Source) GetDICOMStoreMetrics(storeID, tokenStr string) (*healthcare.DicomStoreMetrics, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
}
return store, nil
}
func (s *Source) GetFHIRStoreMetrics(storeID, tokenStr string) (*healthcare.FhirStoreMetrics, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
}
return store, nil
}
func (s *Source) ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
var filtered []*healthcare.DicomStore
for _, store := range stores.DicomStores {
if len(s.AllowedDICOMStores()) == 0 {
filtered = append(filtered, store)
continue
}
if len(store.Name) == 0 {
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := s.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
return filtered, nil
}
func (s *Source) ListFHIRStores(tokenStr string) ([]*healthcare.FhirStore, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
var filtered []*healthcare.FhirStore
for _, store := range stores.FhirStores {
if len(s.AllowedFHIRStores()) == 0 {
filtered = append(filtered, store)
continue
}
if len(store.Name) == 0 {
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := s.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
return filtered, nil
}
func (s *Source) RetrieveRenderedDICOMInstance(storeID, study, series, sop string, frame int, tokenStr string) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
call.Header().Set("Accept", "image/jpeg")
resp, err := call.Do()
if err != nil {
return nil, fmt.Errorf("unable to retrieve dicom instance rendered image: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("RetrieveRendered: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
base64String := base64.StdEncoding.EncodeToString(respBytes)
return base64String, nil
}
func (s *Source) SearchDICOM(toolType, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
var resp *http.Response
switch toolType {
case "cloud-healthcare-search-dicom-instances":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
case "cloud-healthcare-search-dicom-series":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
case "cloud-healthcare-search-dicom-studies":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, dicomWebPath).Do(opts...)
default:
return nil, fmt.Errorf("incompatible tool type: %s", toolType)
}
if err != nil {
return nil, fmt.Errorf("failed to search dicom series: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
if len(respBytes) == 0 {
return []interface{}{}, nil
}
var result []interface{}
if err := json.Unmarshal(respBytes, &result); err != nil {
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
}
return result, nil
}

View File

@@ -43,7 +43,7 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
want: server.SourceConfigs{
"my-instance": cloudhealthcare.Config{
Name: "my-instance",
Type: cloudhealthcare.SourceType,
Kind: cloudhealthcare.SourceKind,
Project: "my-project",
Region: "us-central1",
Dataset: "my-dataset",
@@ -65,7 +65,7 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
want: server.SourceConfigs{
"my-instance": cloudhealthcare.Config{
Name: "my-instance",
Type: cloudhealthcare.SourceType,
Kind: cloudhealthcare.SourceKind,
Project: "my-project",
Region: "us",
Dataset: "my-dataset",
@@ -91,7 +91,7 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
want: server.SourceConfigs{
"my-instance": cloudhealthcare.Config{
Name: "my-instance",
Type: cloudhealthcare.SourceType,
Kind: cloudhealthcare.SourceKind,
Project: "my-project",
Region: "us",
Dataset: "my-dataset",

View File

@@ -29,14 +29,14 @@ import (
monitoring "google.golang.org/api/monitoring/v3"
)
const SourceType string = "cloud-monitoring"
const SourceKind string = "cloud-monitoring"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -50,12 +50,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize initializes a Cloud Monitoring Source instance.
@@ -99,8 +99,8 @@ type Source struct {
userAgent string
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -42,7 +42,7 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-cloud-monitoring-instance": cloudmonitoring.Config{
Name: "my-cloud-monitoring-instance",
Type: cloudmonitoring.SourceType,
Kind: cloudmonitoring.SourceKind,
UseClientOAuth: false,
},
},
@@ -58,7 +58,7 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-cloud-monitoring-instance": cloudmonitoring.Config{
Name: "my-cloud-monitoring-instance",
Type: cloudmonitoring.SourceType,
Kind: cloudmonitoring.SourceKind,
UseClientOAuth: true,
},
},

View File

@@ -34,7 +34,7 @@ import (
sqladmin "google.golang.org/api/sqladmin/v1"
)
const SourceType string = "cloud-sql-admin"
const SourceKind string = "cloud-sql-admin"
var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
@@ -42,8 +42,8 @@ var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/da
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -57,13 +57,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
DefaultProject string `yaml:"defaultProject"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize initializes a CloudSQL Admin Source instance.
@@ -110,8 +110,8 @@ type Source struct {
Service *sqladmin.Service
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -352,28 +352,6 @@ func (s *Source) GetWaitForOperations(ctx context.Context, service *sqladmin.Ser
return nil, nil
}
func (s *Source) InsertBackupRun(ctx context.Context, project, instance, location, backupDescription, accessToken string) (any, error) {
backupRun := &sqladmin.BackupRun{}
if location != "" {
backupRun.Location = location
}
if backupDescription != "" {
backupRun.Description = backupDescription
}
service, err := s.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
resp, err := service.BackupRuns.Insert(project, instance, backupRun).Do()
if err != nil {
return nil, fmt.Errorf("error creating backup: %w", err)
}
return resp, nil
}
func generateCloudSQLConnectionMessage(ctx context.Context, source *Source, logger log.Logger, opResponse map[string]any, connectionMessageTemplate string) (string, bool) {
operationType, ok := opResponse["operationType"].(string)
if !ok || operationType != "CREATE_DATABASE" {

View File

@@ -42,7 +42,7 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-cloud-sql-admin-instance": cloudsqladmin.Config{
Name: "my-cloud-sql-admin-instance",
Type: cloudsqladmin.SourceType,
Kind: cloudsqladmin.SourceKind,
UseClientOAuth: false,
},
},
@@ -58,7 +58,7 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-cloud-sql-admin-instance": cloudsqladmin.Config{
Name: "my-cloud-sql-admin-instance",
Type: cloudsqladmin.SourceType,
Kind: cloudsqladmin.SourceKind,
UseClientOAuth: true,
},
},

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "cloud-sql-mssql"
const SourceKind string = "cloud-sql-mssql"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Cloud SQL MSSQL configs
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
@@ -62,9 +62,9 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigType() string {
func (r Config) SourceConfigKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceType
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -94,9 +94,9 @@ type Source struct {
Db *sql.DB
}
func (s *Source) SourceType() string {
func (s *Source) SourceKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceType
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -152,7 +152,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -46,7 +46,7 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
want: server.SourceConfigs{
"my-instance": cloudsqlmssql.Config{
Name: "my-instance",
Type: cloudsqlmssql.SourceType,
Kind: cloudsqlmssql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -74,7 +74,7 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
want: server.SourceConfigs{
"my-instance": cloudsqlmssql.Config{
Name: "my-instance",
Type: cloudsqlmssql.SourceType,
Kind: cloudsqlmssql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -102,7 +102,7 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
want: server.SourceConfigs{
"my-instance": cloudsqlmssql.Config{
Name: "my-instance",
Type: cloudsqlmssql.SourceType,
Kind: cloudsqlmssql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",

View File

@@ -30,14 +30,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "cloud-sql-mysql"
const SourceKind string = "cloud-sql-mysql"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
@@ -61,8 +61,8 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -90,8 +90,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -184,7 +184,7 @@ func getConnectionConfig(ctx context.Context, user, pass string) (string, string
func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Configure the driver to connect to the database

View File

@@ -46,7 +46,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
want: server.SourceConfigs{
"my-mysql-instance": cloudsqlmysql.Config{
Name: "my-mysql-instance",
Type: cloudsqlmysql.SourceType,
Kind: cloudsqlmysql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -74,7 +74,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
want: server.SourceConfigs{
"my-mysql-instance": cloudsqlmysql.Config{
Name: "my-mysql-instance",
Type: cloudsqlmysql.SourceType,
Kind: cloudsqlmysql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -102,7 +102,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
want: server.SourceConfigs{
"my-mysql-instance": cloudsqlmysql.Config{
Name: "my-mysql-instance",
Type: cloudsqlmysql.SourceType,
Kind: cloudsqlmysql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -130,7 +130,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
want: server.SourceConfigs{
"my-mysql-instance": cloudsqlmysql.Config{
Name: "my-mysql-instance",
Type: cloudsqlmysql.SourceType,
Kind: cloudsqlmysql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",

View File

@@ -28,14 +28,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "cloud-sql-postgres"
const SourceKind string = "cloud-sql-postgres"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
@@ -59,8 +59,8 @@ type Config struct {
Password string `yaml:"password"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -88,8 +88,8 @@ type Source struct {
Pool *pgxpool.Pool
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -162,7 +162,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string
func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Configure the driver to connect to the database

View File

@@ -46,7 +46,7 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
want: server.SourceConfigs{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Type: cloudsqlpg.SourceType,
Kind: cloudsqlpg.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -74,7 +74,7 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
want: server.SourceConfigs{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Type: cloudsqlpg.SourceType,
Kind: cloudsqlpg.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -102,7 +102,7 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
want: server.SourceConfigs{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Type: cloudsqlpg.SourceType,
Kind: cloudsqlpg.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -130,7 +130,7 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
want: server.SourceConfigs{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Type: cloudsqlpg.SourceType,
Kind: cloudsqlpg.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "couchbase"
const SourceKind string = "couchbase"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ConnectionString string `yaml:"connectionString" validate:"required"`
Bucket string `yaml:"bucket" validate:"required"`
Scope string `yaml:"scope" validate:"required"`
@@ -66,8 +66,8 @@ type Config struct {
QueryScanConsistency uint `yaml:"queryScanConsistency"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -96,8 +96,8 @@ type Source struct {
Scope *gocb.Scope
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -45,7 +45,7 @@ func TestParseFromYamlCouchbase(t *testing.T) {
want: server.SourceConfigs{
"my-couchbase-instance": couchbase.Config{
Name: "my-couchbase-instance",
Type: couchbase.SourceType,
Kind: couchbase.SourceKind,
ConnectionString: "localhost",
Username: "Administrator",
Password: "password",
@@ -74,7 +74,7 @@ func TestParseFromYamlCouchbase(t *testing.T) {
want: server.SourceConfigs{
"my-couchbase-instance": couchbase.Config{
Name: "my-couchbase-instance",
Type: couchbase.SourceType,
Kind: couchbase.SourceKind,
ConnectionString: "couchbases://localhost",
Bucket: "travel-sample",
Scope: "inventory",

View File

@@ -19,8 +19,6 @@ import (
"fmt"
dataplexapi "cloud.google.com/go/dataplex/apiv1"
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
"github.com/cenkalti/backoff/v5"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
@@ -29,14 +27,14 @@ import (
"google.golang.org/api/option"
)
const SourceType string = "dataplex"
const SourceKind string = "dataplex"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -51,13 +49,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Dataplex configs
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
}
func (r Config) SourceConfigType() string {
func (r Config) SourceConfigKind() string {
// Returns Dataplex source kind
return SourceType
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -81,9 +79,9 @@ type Source struct {
Client *dataplexapi.CatalogClient
}
func (s *Source) SourceType() string {
func (s *Source) SourceKind() string {
// Returns Dataplex source kind
return SourceType
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -104,7 +102,7 @@ func initDataplexConnection(
name string,
project string,
) (*dataplexapi.CatalogClient, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
cred, err := google.FindDefaultCredentials(ctx)
@@ -123,101 +121,3 @@ func initDataplexConnection(
}
return client, nil
}
func (s *Source) LookupEntry(ctx context.Context, name string, view int, aspectTypes []string, entry string) (*dataplexpb.Entry, error) {
viewMap := map[int]dataplexpb.EntryView{
1: dataplexpb.EntryView_BASIC,
2: dataplexpb.EntryView_FULL,
3: dataplexpb.EntryView_CUSTOM,
4: dataplexpb.EntryView_ALL,
}
req := &dataplexpb.LookupEntryRequest{
Name: name,
View: viewMap[view],
AspectTypes: aspectTypes,
Entry: entry,
}
result, err := s.CatalogClient().LookupEntry(ctx, req)
if err != nil {
return nil, err
}
return result, nil
}
func (s *Source) searchRequest(ctx context.Context, query string, pageSize int, orderBy string) (*dataplexapi.SearchEntriesResultIterator, error) {
// Create SearchEntriesRequest with the provided parameters
req := &dataplexpb.SearchEntriesRequest{
Query: query,
Name: fmt.Sprintf("projects/%s/locations/global", s.ProjectID()),
PageSize: int32(pageSize),
OrderBy: orderBy,
SemanticSearch: true,
}
// Perform the search using the CatalogClient - this will return an iterator
it := s.CatalogClient().SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", s.ProjectID())
}
return it, nil
}
func (s *Source) SearchAspectTypes(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.AspectType, error) {
q := query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype"
it, err := s.searchRequest(ctx, q, pageSize, orderBy)
if err != nil {
return nil, err
}
// Iterate through the search results and call GetAspectType for each result using the resource name
var results []*dataplexpb.AspectType
for {
entry, err := it.Next()
if err != nil {
break
}
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
getAspectBackOff := backoff.NewExponentialBackOff()
resourceName := entry.DataplexEntry.GetEntrySource().Resource
getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
Name: resourceName,
}
operation := func() (*dataplexpb.AspectType, error) {
aspectType, err := s.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
if err != nil {
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
}
return aspectType, nil
}
// Retry the GetAspectType operation with exponential backoff
aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
if err != nil {
return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
}
results = append(results, aspectType)
}
return results, nil
}
func (s *Source) SearchEntries(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.SearchEntriesResult, error) {
it, err := s.searchRequest(ctx, query, pageSize, orderBy)
if err != nil {
return nil, err
}
var results []*dataplexpb.SearchEntriesResult
for {
entry, err := it.Next()
if err != nil {
break
}
results = append(results, entry)
}
return results, nil
}

View File

@@ -41,7 +41,7 @@ func TestParseFromYamlDataplex(t *testing.T) {
want: server.SourceConfigs{
"my-instance": dataplex.Config{
Name: "my-instance",
Type: dataplex.SourceType,
Kind: dataplex.SourceKind,
Project: "my-project",
},
},

View File

@@ -30,14 +30,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "dgraph"
const SourceKind string = "dgraph"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -67,7 +67,7 @@ type DgraphClient struct {
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
DgraphUrl string `yaml:"dgraphUrl" validate:"required"`
User string `yaml:"user"`
Password string `yaml:"password"`
@@ -75,8 +75,8 @@ type Config struct {
ApiKey string `yaml:"apiKey"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -103,8 +103,8 @@ type Source struct {
Client *DgraphClient `yaml:"client"`
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -139,7 +139,7 @@ func (s *Source) RunSQL(statement string, params parameters.ParamValues, isQuery
func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (*DgraphClient, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, r.Name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, r.Name)
defer span.End()
if r.DgraphUrl == "" {

View File

@@ -45,7 +45,7 @@ func TestParseFromYamlDgraph(t *testing.T) {
want: server.SourceConfigs{
"my-dgraph-instance": dgraph.Config{
Name: "my-dgraph-instance",
Type: dgraph.SourceType,
Kind: dgraph.SourceKind,
DgraphUrl: "https://localhost:8080",
ApiKey: "abc123",
Password: "pass@123",
@@ -65,7 +65,7 @@ func TestParseFromYamlDgraph(t *testing.T) {
want: server.SourceConfigs{
"my-dgraph-instance": dgraph.Config{
Name: "my-dgraph-instance",
Type: dgraph.SourceType,
Kind: dgraph.SourceKind,
DgraphUrl: "https://localhost:8080",
},
},

View File

@@ -30,14 +30,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "elasticsearch"
const SourceKind string = "elasticsearch"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -51,15 +51,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Addresses []string `yaml:"addresses" validate:"required"`
Username string `yaml:"username"`
Password string `yaml:"password"`
APIKey string `yaml:"apikey"`
}
func (c Config) SourceConfigType() string {
return SourceType
func (c Config) SourceConfigKind() string {
return SourceKind
}
type EsClient interface {
@@ -139,9 +139,9 @@ func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return s, nil
}
// SourceType returns the kind string for this source.
func (s *Source) SourceType() string {
return SourceType
// SourceKind returns the kind string for this source.
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -43,7 +43,7 @@ func TestParseFromYamlElasticsearch(t *testing.T) {
want: server.SourceConfigs{
"my-es-instance": elasticsearch.Config{
Name: "my-es-instance",
Type: elasticsearch.SourceType,
Kind: elasticsearch.SourceKind,
Addresses: []string{"http://localhost:9200"},
APIKey: "somekey",
},

View File

@@ -27,13 +27,13 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
)
const SourceType string = "firebird"
const SourceKind string = "firebird"
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -47,7 +47,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -55,8 +55,8 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -84,8 +84,8 @@ type Source struct {
Db *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -144,7 +144,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
}
func initFirebirdConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*sql.DB, error) {
_, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// urlExample := "user:password@host:port/path/to/database.fdb"

View File

@@ -45,7 +45,7 @@ func TestParseFromYamlFirebird(t *testing.T) {
want: server.SourceConfigs{
"my-fdb-instance": firebird.Config{
Name: "my-fdb-instance",
Type: firebird.SourceType,
Kind: firebird.SourceKind,
Host: "my-host",
Port: "my-port",
Database: "my_db",

View File

@@ -16,10 +16,7 @@ package firestore
import (
"context"
"encoding/base64"
"fmt"
"strings"
"time"
"cloud.google.com/go/firestore"
"github.com/goccy/go-yaml"
@@ -28,17 +25,16 @@ import (
"go.opentelemetry.io/otel/trace"
"google.golang.org/api/firebaserules/v1"
"google.golang.org/api/option"
"google.golang.org/genproto/googleapis/type/latlng"
)
const SourceType string = "firestore"
const SourceKind string = "firestore"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -53,14 +49,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Firestore configs
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Database string `yaml:"database"` // Optional, defaults to "(default)"
}
func (r Config) SourceConfigType() string {
func (r Config) SourceConfigKind() string {
// Returns Firestore source kind
return SourceType
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -92,9 +88,9 @@ type Source struct {
RulesClient *firebaserules.Service
}
func (s *Source) SourceType() string {
func (s *Source) SourceKind() string {
// Returns Firestore source kind
return SourceType
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -117,476 +113,6 @@ func (s *Source) GetDatabaseId() string {
return s.Database
}
// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
// This removes type information and returns plain values
func FirestoreValueToJSON(value any) any {
if value == nil {
return nil
}
switch v := value.(type) {
case time.Time:
return v.Format(time.RFC3339Nano)
case *latlng.LatLng:
return map[string]any{
"latitude": v.Latitude,
"longitude": v.Longitude,
}
case []byte:
return base64.StdEncoding.EncodeToString(v)
case []any:
result := make([]any, len(v))
for i, item := range v {
result[i] = FirestoreValueToJSON(item)
}
return result
case map[string]any:
result := make(map[string]any)
for k, val := range v {
result[k] = FirestoreValueToJSON(val)
}
return result
case *firestore.DocumentRef:
return v.Path
default:
return value
}
}
// BuildQuery constructs the Firestore query from parameters
func (s *Source) BuildQuery(collectionPath string, filter firestore.EntityFilter, selectFields []string, field string, direction firestore.Direction, limit int, analyzeQuery bool) (*firestore.Query, error) {
collection := s.FirestoreClient().Collection(collectionPath)
query := collection.Query
// Process and apply filters if template is provided
if filter != nil {
query = query.WhereEntity(filter)
}
if len(selectFields) > 0 {
query = query.Select(selectFields...)
}
if field != "" {
query = query.OrderBy(field, direction)
}
query = query.Limit(limit)
// Apply analyze options if enabled
if analyzeQuery {
query = query.WithRunOptions(firestore.ExplainOptions{
Analyze: true,
})
}
return &query, nil
}
// QueryResult represents a document result from the query
type QueryResult struct {
ID string `json:"id"`
Path string `json:"path"`
Data map[string]any `json:"data"`
CreateTime any `json:"createTime,omitempty"`
UpdateTime any `json:"updateTime,omitempty"`
ReadTime any `json:"readTime,omitempty"`
}
// QueryResponse represents the full response including optional metrics
type QueryResponse struct {
Documents []QueryResult `json:"documents"`
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
}
// ExecuteQuery runs the query and formats the results
func (s *Source) ExecuteQuery(ctx context.Context, query *firestore.Query, analyzeQuery bool) (any, error) {
docIterator := query.Documents(ctx)
docs, err := docIterator.GetAll()
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
}
// Convert results to structured format
results := make([]QueryResult, len(docs))
for i, doc := range docs {
results[i] = QueryResult{
ID: doc.Ref.ID,
Path: doc.Ref.Path,
Data: doc.Data(),
CreateTime: doc.CreateTime,
UpdateTime: doc.UpdateTime,
ReadTime: doc.ReadTime,
}
}
// Return with explain metrics if requested
if analyzeQuery {
explainMetrics, err := getExplainMetrics(docIterator)
if err == nil && explainMetrics != nil {
response := QueryResponse{
Documents: results,
ExplainMetrics: explainMetrics,
}
return response, nil
}
}
return results, nil
}
// getExplainMetrics extracts explain metrics from the query iterator
func getExplainMetrics(docIterator *firestore.DocumentIterator) (map[string]any, error) {
explainMetrics, err := docIterator.ExplainMetrics()
if err != nil || explainMetrics == nil {
return nil, err
}
metricsData := make(map[string]any)
// Add plan summary if available
if explainMetrics.PlanSummary != nil {
planSummary := make(map[string]any)
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
metricsData["planSummary"] = planSummary
}
// Add execution stats if available
if explainMetrics.ExecutionStats != nil {
executionStats := make(map[string]any)
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
}
if explainMetrics.ExecutionStats.DebugStats != nil {
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
}
metricsData["executionStats"] = executionStats
}
return metricsData, nil
}
func (s *Source) GetDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
// Create document references from paths
docRefs := make([]*firestore.DocumentRef, len(documentPaths))
for i, path := range documentPaths {
docRefs[i] = s.FirestoreClient().Doc(path)
}
// Get all documents
snapshots, err := s.FirestoreClient().GetAll(ctx, docRefs)
if err != nil {
return nil, fmt.Errorf("failed to get documents: %w", err)
}
// Convert snapshots to response data
results := make([]any, len(snapshots))
for i, snapshot := range snapshots {
docData := make(map[string]any)
docData["path"] = documentPaths[i]
docData["exists"] = snapshot.Exists()
if snapshot.Exists() {
docData["data"] = snapshot.Data()
docData["createTime"] = snapshot.CreateTime
docData["updateTime"] = snapshot.UpdateTime
docData["readTime"] = snapshot.ReadTime
}
results[i] = docData
}
return results, nil
}
func (s *Source) AddDocuments(ctx context.Context, collectionPath string, documentData any, returnData bool) (map[string]any, error) {
// Get the collection reference
collection := s.FirestoreClient().Collection(collectionPath)
// Add the document to the collection
docRef, writeResult, err := collection.Add(ctx, documentData)
if err != nil {
return nil, fmt.Errorf("failed to add document: %w", err)
}
// Build the response
response := map[string]any{
"documentPath": docRef.Path,
"createTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
}
// Add document data if requested
if returnData {
// Fetch the updated document to return the current state
snapshot, err := docRef.Get(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
}
// Convert the document data back to simple JSON format
simplifiedData := FirestoreValueToJSON(snapshot.Data())
response["documentData"] = simplifiedData
}
return response, nil
}
func (s *Source) UpdateDocument(ctx context.Context, documentPath string, updates []firestore.Update, documentData any, returnData bool) (map[string]any, error) {
// Get the document reference
docRef := s.FirestoreClient().Doc(documentPath)
// Prepare update data
var writeResult *firestore.WriteResult
var writeErr error
if len(updates) > 0 {
writeResult, writeErr = docRef.Update(ctx, updates)
} else {
writeResult, writeErr = docRef.Set(ctx, documentData, firestore.MergeAll)
}
if writeErr != nil {
return nil, fmt.Errorf("failed to update document: %w", writeErr)
}
// Build the response
response := map[string]any{
"documentPath": docRef.Path,
"updateTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
}
// Add document data if requested
if returnData {
// Fetch the updated document to return the current state
snapshot, err := docRef.Get(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
}
// Convert the document data to simple JSON format
simplifiedData := FirestoreValueToJSON(snapshot.Data())
response["documentData"] = simplifiedData
}
return response, nil
}
func (s *Source) DeleteDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
// Create a BulkWriter to handle multiple deletions efficiently
bulkWriter := s.FirestoreClient().BulkWriter(ctx)
// Keep track of jobs for each document
jobs := make([]*firestore.BulkWriterJob, len(documentPaths))
// Add all delete operations to the BulkWriter
for i, path := range documentPaths {
docRef := s.FirestoreClient().Doc(path)
job, err := bulkWriter.Delete(docRef)
if err != nil {
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
}
jobs[i] = job
}
// End the BulkWriter to execute all operations
bulkWriter.End()
// Collect results
results := make([]any, len(documentPaths))
for i, job := range jobs {
docData := make(map[string]any)
docData["path"] = documentPaths[i]
// Wait for the job to complete and get the result
_, err := job.Results()
if err != nil {
docData["success"] = false
docData["error"] = err.Error()
} else {
docData["success"] = true
}
results[i] = docData
}
return results, nil
}
func (s *Source) ListCollections(ctx context.Context, parentPath string) ([]any, error) {
var collectionRefs []*firestore.CollectionRef
var err error
if parentPath != "" {
// List subcollections of the specified document
docRef := s.FirestoreClient().Doc(parentPath)
collectionRefs, err = docRef.Collections(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
}
} else {
// List root collections
collectionRefs, err = s.FirestoreClient().Collections(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to list root collections: %w", err)
}
}
// Convert collection references to response data
results := make([]any, len(collectionRefs))
for i, collRef := range collectionRefs {
collData := make(map[string]any)
collData["id"] = collRef.ID
collData["path"] = collRef.Path
// If this is a subcollection, include parent information
if collRef.Parent != nil {
collData["parent"] = collRef.Parent.Path
}
results[i] = collData
}
return results, nil
}
func (s *Source) GetRules(ctx context.Context) (any, error) {
// Get the latest release for Firestore
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", s.GetProjectId(), s.GetDatabaseId())
release, err := s.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
}
if release.RulesetName == "" {
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", s.GetProjectId(), s.GetDatabaseId())
}
// Get the ruleset content
ruleset, err := s.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
}
if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
return nil, fmt.Errorf("no rules files found in ruleset")
}
return ruleset, nil
}
// SourcePosition represents the location of an issue in the source
type SourcePosition struct {
FileName string `json:"fileName,omitempty"`
Line int64 `json:"line"` // 1-based
Column int64 `json:"column"` // 1-based
CurrentOffset int64 `json:"currentOffset"` // 0-based, inclusive start
EndOffset int64 `json:"endOffset"` // 0-based, exclusive end
}
// Issue represents a validation issue in the rules
type Issue struct {
SourcePosition SourcePosition `json:"sourcePosition"`
Description string `json:"description"`
Severity string `json:"severity"`
}
// ValidationResult represents the result of rules validation
type ValidationResult struct {
Valid bool `json:"valid"`
IssueCount int `json:"issueCount"`
FormattedIssues string `json:"formattedIssues,omitempty"`
RawIssues []Issue `json:"rawIssues,omitempty"`
}
func (s *Source) ValidateRules(ctx context.Context, sourceParam string) (any, error) {
// Create test request
testRequest := &firebaserules.TestRulesetRequest{
Source: &firebaserules.Source{
Files: []*firebaserules.File{
{
Name: "firestore.rules",
Content: sourceParam,
},
},
},
// We don't need test cases for validation only
TestSuite: &firebaserules.TestSuite{
TestCases: []*firebaserules.TestCase{},
},
}
// Call the test API
projectName := fmt.Sprintf("projects/%s", s.GetProjectId())
response, err := s.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to validate rules: %w", err)
}
// Process the response
if len(response.Issues) == 0 {
return ValidationResult{
Valid: true,
IssueCount: 0,
FormattedIssues: "✓ No errors detected. Rules are valid.",
}, nil
}
// Convert issues to our format
issues := make([]Issue, len(response.Issues))
for i, issue := range response.Issues {
issues[i] = Issue{
Description: issue.Description,
Severity: issue.Severity,
SourcePosition: SourcePosition{
FileName: issue.SourcePosition.FileName,
Line: issue.SourcePosition.Line,
Column: issue.SourcePosition.Column,
CurrentOffset: issue.SourcePosition.CurrentOffset,
EndOffset: issue.SourcePosition.EndOffset,
},
}
}
// Format issues
sourceLines := strings.Split(sourceParam, "\n")
var formattedOutput []string
formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))
for _, issue := range issues {
issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
issue.Severity,
issue.Description,
issue.SourcePosition.Line,
issue.SourcePosition.Column)
if issue.SourcePosition.Line > 0 {
lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
if lineIndex >= 0 && lineIndex < len(sourceLines) {
errorLine := sourceLines[lineIndex]
issueString += fmt.Sprintf("\n```\n%s", errorLine)
// Add carets if we have column and offset information
if issue.SourcePosition.Column > 0 &&
issue.SourcePosition.CurrentOffset >= 0 &&
issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {
startColumn := int(issue.SourcePosition.Column - 1) // 0-based
errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)
if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
padding := strings.Repeat(" ", startColumn)
carets := strings.Repeat("^", errorTokenLength)
issueString += fmt.Sprintf("\n%s%s", padding, carets)
}
}
issueString += "\n```"
}
}
formattedOutput = append(formattedOutput, issueString)
}
formattedIssues := strings.Join(formattedOutput, "\n\n")
return ValidationResult{
Valid: false,
IssueCount: len(issues),
FormattedIssues: formattedIssues,
RawIssues: issues,
}, nil
}
func initFirestoreConnection(
ctx context.Context,
tracer trace.Tracer,
@@ -594,7 +120,7 @@ func initFirestoreConnection(
project string,
database string,
) (*firestore.Client, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -16,7 +16,6 @@ package firestore_test
import (
"testing"
"time"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
@@ -42,7 +41,7 @@ func TestParseFromYamlFirestore(t *testing.T) {
want: server.SourceConfigs{
"my-firestore": firestore.Config{
Name: "my-firestore",
Type: firestore.SourceType,
Kind: firestore.SourceKind,
Project: "my-project",
Database: "",
},
@@ -60,7 +59,7 @@ func TestParseFromYamlFirestore(t *testing.T) {
want: server.SourceConfigs{
"my-firestore": firestore.Config{
Name: "my-firestore",
Type: firestore.SourceType,
Kind: firestore.SourceKind,
Project: "my-project",
Database: "my-database",
},
@@ -129,37 +128,3 @@ func TestFailParseFromYamlFirestore(t *testing.T) {
})
}
}
func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
// Test round-trip conversion
original := map[string]any{
"name": "Test",
"count": int64(42),
"price": 19.99,
"active": true,
"tags": []any{"tag1", "tag2"},
"metadata": map[string]any{
"created": time.Now(),
},
"nullField": nil,
}
// Convert to JSON representation
jsonRepresentation := firestore.FirestoreValueToJSON(original)
// Verify types are simplified
jsonMap, ok := jsonRepresentation.(map[string]any)
if !ok {
t.Fatalf("Expected map, got %T", jsonRepresentation)
}
// Time should be converted to string
metadata, ok := jsonMap["metadata"].(map[string]any)
if !ok {
t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
}
_, ok = metadata["created"].(string)
if !ok {
t.Errorf("created should be a string, got %T", metadata["created"])
}
}

View File

@@ -16,9 +16,7 @@ package http
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
@@ -29,14 +27,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "http"
const SourceKind string = "http"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -50,7 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
BaseURL string `yaml:"baseUrl"`
Timeout string `yaml:"timeout"`
DefaultHeaders map[string]string `yaml:"headers"`
@@ -58,8 +56,8 @@ type Config struct {
DisableSslVerification bool `yaml:"disableSslVerification"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize initializes an HTTP Source instance.
@@ -122,8 +120,8 @@ type Source struct {
client *http.Client
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -145,28 +143,3 @@ func (s *Source) HttpQueryParams() map[string]string {
func (s *Source) Client() *http.Client {
return s.client
}
func (s *Source) RunRequest(req *http.Request) (any, error) {
// Make request and fetch response
resp, err := s.Client().Do(req)
if err != nil {
return nil, fmt.Errorf("error making HTTP request: %s", err)
}
defer resp.Body.Close()
var body []byte
body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
}
var data any
if err = json.Unmarshal(body, &data); err != nil {
// if unable to unmarshal data, return result as string.
return string(body), nil
}
return data, nil
}

View File

@@ -42,7 +42,7 @@ func TestParseFromYamlHttp(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-http-instance": http.Config{
Name: "my-http-instance",
Type: http.SourceType,
Kind: http.SourceKind,
BaseURL: "http://test_server/",
Timeout: "30s",
DisableSslVerification: false,
@@ -68,7 +68,7 @@ func TestParseFromYamlHttp(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-http-instance": http.Config{
Name: "my-http-instance",
Type: http.SourceType,
Kind: http.SourceKind,
BaseURL: "http://test_server/",
Timeout: "10s",
DefaultHeaders: map[string]string{"Authorization": "test_header", "Custom-Header": "custom"},

View File

@@ -15,9 +15,7 @@ package looker
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"strings"
"time"
@@ -33,14 +31,14 @@ import (
v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4"
)
const SourceType string = "looker"
const SourceKind string = "looker"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -64,7 +62,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
BaseURL string `yaml:"base_url" validate:"required"`
ClientId string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
@@ -79,8 +77,8 @@ type Config struct {
SessionLength int64 `yaml:"sessionLength"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize initializes a Looker Source instance.
@@ -154,8 +152,8 @@ type Source struct {
AuthTokenHeaderName string
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -210,49 +208,6 @@ func (s *Source) LookerSessionLength() int64 {
return s.SessionLength
}
// Make types for RoundTripper
type transportWithAuthHeader struct {
Base http.RoundTripper
AuthToken string
}
func (t *transportWithAuthHeader) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("x-looker-appid", "go-sdk")
req.Header.Set("Authorization", t.AuthToken)
return t.Base.RoundTrip(req)
}
func (s *Source) GetLookerSDK(accessToken string) (*v4.LookerSDK, error) {
if s.UseClientAuthorization() {
if accessToken == "" {
return nil, fmt.Errorf("no access token supplied with request")
}
session := rtl.NewAuthSession(*s.LookerApiSettings())
// Configure base transport with TLS
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: !s.LookerApiSettings().VerifySsl,
},
}
// Build transport for end user token
session.Client = http.Client{
Transport: &transportWithAuthHeader{
Base: transport,
AuthToken: accessToken,
},
}
// return SDK with new Transport
return v4.NewLookerSDK(session), nil
}
if s.LookerClient() == nil {
return nil, fmt.Errorf("client id or client secret not valid")
}
return s.LookerClient(), nil
}
func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) {
cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...)
if err != nil {

View File

@@ -44,7 +44,7 @@ func TestParseFromYamlLooker(t *testing.T) {
want: map[string]sources.SourceConfig{
"my-looker-instance": looker.Config{
Name: "my-looker-instance",
Type: looker.SourceType,
Kind: looker.SourceKind,
BaseURL: "http://example.looker.com/",
ClientId: "jasdl;k;tjl",
ClientSecret: "sdakl;jgflkasdfkfg",

View File

@@ -27,14 +27,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "mindsdb"
const SourceKind string = "mindsdb"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -48,7 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -57,8 +57,8 @@ type Config struct {
QueryTimeout string `yaml:"queryTimeout"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -86,8 +86,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -159,7 +159,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initMindsDBConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Configure the driver to connect to the database

View File

@@ -45,7 +45,7 @@ func TestParseFromYamlMindsDB(t *testing.T) {
want: server.SourceConfigs{
"my-mindsdb-instance": mindsdb.Config{
Name: "my-mindsdb-instance",
Type: mindsdb.SourceType,
Kind: mindsdb.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -70,7 +70,7 @@ func TestParseFromYamlMindsDB(t *testing.T) {
want: server.SourceConfigs{
"my-mindsdb-instance": mindsdb.Config{
Name: "my-mindsdb-instance",
Type: mindsdb.SourceType,
Kind: mindsdb.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",

View File

@@ -16,27 +16,24 @@ package mongodb
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "mongodb"
const SourceKind string = "mongodb"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -50,12 +47,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Uri string `yaml:"uri" validate:"required"` // MongoDB Atlas connection URI
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -84,8 +81,8 @@ type Source struct {
Client *mongo.Client
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -96,204 +93,9 @@ func (s *Source) MongoClient() *mongo.Client {
return s.Client
}
func parseData(ctx context.Context, cur *mongo.Cursor) ([]any, error) {
var data = []any{}
err := cur.All(ctx, &data)
if err != nil {
return nil, err
}
var final []any
for _, item := range data {
tmp, _ := bson.MarshalExtJSON(item, false, false)
var tmp2 any
err = json.Unmarshal(tmp, &tmp2)
if err != nil {
return nil, err
}
final = append(final, tmp2)
}
return final, err
}
func (s *Source) Aggregate(ctx context.Context, pipelineString string, canonical, readOnly bool, database, collection string) ([]any, error) {
var pipeline = []bson.M{}
err := bson.UnmarshalExtJSON([]byte(pipelineString), canonical, &pipeline)
if err != nil {
return nil, err
}
if readOnly {
//fail if we do a merge or an out
for _, stage := range pipeline {
for key := range stage {
if key == "$merge" || key == "$out" {
return nil, fmt.Errorf("this is not a read-only pipeline: %+v", stage)
}
}
}
}
cur, err := s.MongoClient().Database(database).Collection(collection).Aggregate(ctx, pipeline)
if err != nil {
return nil, err
}
defer cur.Close(ctx)
res, err := parseData(ctx, cur)
if err != nil {
return nil, err
}
if res == nil {
return []any{}, nil
}
return res, err
}
func (s *Source) Find(ctx context.Context, filterString, database, collection string, opts *options.FindOptions) ([]any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}
cur, err := s.MongoClient().Database(database).Collection(collection).Find(ctx, filter, opts)
if err != nil {
return nil, err
}
defer cur.Close(ctx)
return parseData(ctx, cur)
}
func (s *Source) FindOne(ctx context.Context, filterString, database, collection string, opts *options.FindOneOptions) ([]any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}
res := s.MongoClient().Database(database).Collection(collection).FindOne(ctx, filter, opts)
if res.Err() != nil {
return nil, res.Err()
}
var data any
err = res.Decode(&data)
if err != nil {
return nil, err
}
var final []any
tmp, _ := bson.MarshalExtJSON(data, false, false)
var tmp2 any
err = json.Unmarshal(tmp, &tmp2)
if err != nil {
return nil, err
}
final = append(final, tmp2)
return final, err
}
func (s *Source) InsertMany(ctx context.Context, jsonData string, canonical bool, database, collection string) ([]any, error) {
var data = []any{}
err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data)
if err != nil {
return nil, err
}
res, err := s.MongoClient().Database(database).Collection(collection).InsertMany(ctx, data, options.InsertMany())
if err != nil {
return nil, err
}
return res.InsertedIDs, nil
}
func (s *Source) InsertOne(ctx context.Context, jsonData string, canonical bool, database, collection string) (any, error) {
var data any
err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data)
if err != nil {
return nil, err
}
res, err := s.MongoClient().Database(database).Collection(collection).InsertOne(ctx, data, options.InsertOne())
if err != nil {
return nil, err
}
return res.InsertedID, nil
}
func (s *Source) UpdateMany(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) ([]any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), canonical, &filter)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
}
var update = bson.D{}
err = bson.UnmarshalExtJSON([]byte(updateString), false, &update)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
}
res, err := s.MongoClient().Database(database).Collection(collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(upsert))
if err != nil {
return nil, fmt.Errorf("error updating collection: %w", err)
}
return []any{res.ModifiedCount, res.UpsertedCount, res.MatchedCount}, nil
}
func (s *Source) UpdateOne(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) (any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
}
var update = bson.D{}
err = bson.UnmarshalExtJSON([]byte(updateString), canonical, &update)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
}
res, err := s.MongoClient().Database(database).Collection(collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(upsert))
if err != nil {
return nil, fmt.Errorf("error updating collection: %w", err)
}
return res.ModifiedCount, nil
}
func (s *Source) DeleteMany(ctx context.Context, filterString, database, collection string) (any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}
res, err := s.MongoClient().Database(database).Collection(collection).DeleteMany(ctx, filter, options.Delete())
if err != nil {
return nil, err
}
if res.DeletedCount == 0 {
return nil, errors.New("no document found")
}
return res.DeletedCount, nil
}
func (s *Source) DeleteOne(ctx context.Context, filterString, database, collection string) (any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}
res, err := s.MongoClient().Database(database).Collection(collection).DeleteOne(ctx, filter, options.Delete())
if err != nil {
return nil, err
}
return res.DeletedCount, nil
}
func initMongoDBClient(ctx context.Context, tracer trace.Tracer, name, uri string) (*mongo.Client, error) {
// Start a tracing span
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -41,7 +41,7 @@ func TestParseFromYamlMongoDB(t *testing.T) {
want: server.SourceConfigs{
"mongo-db": mongodb.Config{
Name: "mongo-db",
Type: mongodb.SourceType,
Kind: mongodb.SourceKind,
Uri: "mongodb+srv://username:password@host/dbname",
},
},

View File

@@ -28,14 +28,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "mssql"
const SourceKind string = "mssql"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Cloud SQL MSSQL configs
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -59,9 +59,9 @@ type Config struct {
Encrypt string `yaml:"encrypt"`
}
func (r Config) SourceConfigType() string {
func (r Config) SourceConfigKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceType
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -91,9 +91,9 @@ type Source struct {
Db *sql.DB
}
func (s *Source) SourceType() string {
func (s *Source) SourceKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceType
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -156,7 +156,7 @@ func initMssqlConnection(
error,
) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -45,7 +45,7 @@ func TestParseFromYamlMssql(t *testing.T) {
want: server.SourceConfigs{
"my-mssql-instance": mssql.Config{
Name: "my-mssql-instance",
Type: mssql.SourceType,
Kind: mssql.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -70,7 +70,7 @@ func TestParseFromYamlMssql(t *testing.T) {
want: server.SourceConfigs{
"my-mssql-instance": mssql.Config{
Name: "my-mssql-instance",
Type: mssql.SourceType,
Kind: mssql.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",

View File

@@ -30,14 +30,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "mysql"
const SourceKind string = "mysql"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -61,8 +61,8 @@ type Config struct {
QueryParams map[string]string `yaml:"queryParams"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -90,8 +90,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -158,7 +158,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string, queryParams map[string]string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Build query parameters via url.Values for deterministic order and proper escaping.

View File

@@ -50,7 +50,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
want: server.SourceConfigs{
"my-mysql-instance": mysql.Config{
Name: "my-mysql-instance",
Type: mysql.SourceType,
Kind: mysql.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -75,7 +75,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
want: server.SourceConfigs{
"my-mysql-instance": mysql.Config{
Name: "my-mysql-instance",
Type: mysql.SourceType,
Kind: mysql.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -103,7 +103,7 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
want: server.SourceConfigs{
"my-mysql-instance": mysql.Config{
Name: "my-mysql-instance",
Type: mysql.SourceType,
Kind: mysql.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -211,7 +211,7 @@ func TestFailInitialization(t *testing.T) {
cfg := mysql.Config{
Name: "instance",
Type: "mysql",
Kind: "mysql",
Host: "localhost",
Port: "3306",
Database: "db",

View File

@@ -29,7 +29,7 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "neo4j"
const SourceKind string = "neo4j"
var sourceClassifier *classifier.QueryClassifier = classifier.NewQueryClassifier()
@@ -37,8 +37,8 @@ var sourceClassifier *classifier.QueryClassifier = classifier.NewQueryClassifier
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -52,15 +52,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Uri string `yaml:"uri" validate:"required"`
User string `yaml:"user" validate:"required"`
Password string `yaml:"password" validate:"required"`
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -91,8 +91,8 @@ type Source struct {
Driver neo4j.DriverWithContext
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -182,7 +182,7 @@ func addPlanChildren(p neo4j.Plan) []map[string]any {
func initNeo4jDriver(ctx context.Context, tracer trace.Tracer, uri, user, password, name string) (neo4j.DriverWithContext, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
auth := neo4j.BasicAuth(user, password, "")

View File

@@ -44,7 +44,7 @@ func TestParseFromYamlNeo4j(t *testing.T) {
want: server.SourceConfigs{
"my-neo4j-instance": neo4j.Config{
Name: "my-neo4j-instance",
Type: neo4j.SourceType,
Kind: neo4j.SourceKind,
Uri: "neo4j+s://my-host:7687",
Database: "my_db",
User: "my_user",

View File

@@ -27,14 +27,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "oceanbase"
const SourceKind string = "oceanbase"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -48,7 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -57,8 +57,8 @@ type Config struct {
QueryTimeout string `yaml:"queryTimeout"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -86,8 +86,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -153,7 +153,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
}
func initOceanBaseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
_, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, pass, host, port, dbname)

View File

@@ -46,7 +46,7 @@ func TestParseFromYamlOceanBase(t *testing.T) {
want: server.SourceConfigs{
"my-oceanbase-instance": oceanbase.Config{
Name: "my-oceanbase-instance",
Type: oceanbase.SourceType,
Kind: oceanbase.SourceKind,
Host: "0.0.0.0",
Port: "2881",
Database: "ob_db",
@@ -71,7 +71,7 @@ func TestParseFromYamlOceanBase(t *testing.T) {
want: server.SourceConfigs{
"my-oceanbase-instance": oceanbase.Config{
Name: "my-oceanbase-instance",
Type: oceanbase.SourceType,
Kind: oceanbase.SourceKind,
Host: "0.0.0.0",
Port: "2881",
Database: "ob_db",

View File

@@ -18,14 +18,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "oracle"
const SourceKind string = "oracle"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -45,7 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ConnectionString string `yaml:"connectionString,omitempty"`
TnsAlias string `yaml:"tnsAlias,omitempty"`
TnsAdmin string `yaml:"tnsAdmin,omitempty"`
@@ -95,8 +95,8 @@ func (c Config) validate() error {
return nil
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -124,8 +124,8 @@ type Source struct {
DB *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -239,7 +239,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Config) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, config.Name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, config.Name)
defer span.End()
logger, err := util.LoggerFromContext(ctx)

View File

@@ -33,7 +33,7 @@ func TestParseFromYamlOracle(t *testing.T) {
want: server.SourceConfigs{
"my-oracle-cs": oracle.Config{
Name: "my-oracle-cs",
Type: oracle.SourceType,
Kind: oracle.SourceKind,
ConnectionString: "my-host:1521/XEPDB1",
User: "my_user",
Password: "my_pass",
@@ -56,7 +56,7 @@ func TestParseFromYamlOracle(t *testing.T) {
want: server.SourceConfigs{
"my-oracle-host": oracle.Config{
Name: "my-oracle-host",
Type: oracle.SourceType,
Kind: oracle.SourceKind,
Host: "my-host",
Port: 1521,
ServiceName: "ORCLPDB",
@@ -81,7 +81,7 @@ func TestParseFromYamlOracle(t *testing.T) {
want: server.SourceConfigs{
"my-oracle-tns-oci": oracle.Config{
Name: "my-oracle-tns-oci",
Type: oracle.SourceType,
Kind: oracle.SourceKind,
TnsAlias: "FINANCE_DB",
TnsAdmin: "/opt/oracle/network/admin",
User: "my_user",

View File

@@ -28,14 +28,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "postgres"
const SourceKind string = "postgres"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -58,8 +58,8 @@ type Config struct {
QueryParams map[string]string `yaml:"queryParams"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -87,8 +87,8 @@ type Source struct {
Pool *pgxpool.Pool
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -128,7 +128,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)
if err != nil {

View File

@@ -47,7 +47,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
want: server.SourceConfigs{
"my-pg-instance": postgres.Config{
Name: "my-pg-instance",
Type: postgres.SourceType,
Kind: postgres.SourceKind,
Host: "my-host",
Port: "my-port",
Database: "my_db",
@@ -74,7 +74,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
want: server.SourceConfigs{
"my-pg-instance": postgres.Config{
Name: "my-pg-instance",
Type: postgres.SourceType,
Kind: postgres.SourceKind,
Host: "my-host",
Port: "my-port",
Database: "my_db",

View File

@@ -24,14 +24,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "redis"
const SourceKind string = "redis"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -45,7 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Address []string `yaml:"address" validate:"required"`
Username string `yaml:"username"`
Password string `yaml:"password"`
@@ -54,8 +54,8 @@ type Config struct {
ClusterEnabled bool `yaml:"clusterEnabled"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
// RedisClient is an interface for `redis.Client` and `redis.ClusterClient
@@ -141,8 +141,8 @@ type Source struct {
Client RedisClient
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -43,7 +43,7 @@ func TestParseFromYamlRedis(t *testing.T) {
want: server.SourceConfigs{
"my-redis-instance": redis.Config{
Name: "my-redis-instance",
Type: redis.SourceType,
Kind: redis.SourceKind,
Address: []string{"127.0.0.1"},
ClusterEnabled: false,
UseGCPIAM: false,
@@ -66,7 +66,7 @@ func TestParseFromYamlRedis(t *testing.T) {
want: server.SourceConfigs{
"my-redis-instance": redis.Config{
Name: "my-redis-instance",
Type: redis.SourceType,
Kind: redis.SourceKind,
Address: []string{"127.0.0.1"},
Password: "my-pass",
Database: 1,

View File

@@ -16,31 +16,25 @@ package serverlessspark
import (
"context"
"encoding/json"
"fmt"
"time"
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
longrunning "cloud.google.com/go/longrunning/autogen"
"cloud.google.com/go/longrunning/autogen/longrunningpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/protobuf/encoding/protojson"
)
const SourceType string = "serverless-spark"
const SourceKind string = "serverless-spark"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -54,13 +48,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Location string `yaml:"location" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -94,8 +88,8 @@ type Source struct {
OpsClient *longrunning.OperationsClient
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -127,168 +121,3 @@ func (s *Source) Close() error {
}
return nil
}
func (s *Source) CancelOperation(ctx context.Context, operation string) (any, error) {
req := &longrunningpb.CancelOperationRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", s.GetProject(), s.GetLocation(), operation),
}
client, err := s.GetOperationsClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get operations client: %w", err)
}
err = client.CancelOperation(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to cancel operation: %w", err)
}
return fmt.Sprintf("Cancelled [%s].", operation), nil
}
func (s *Source) CreateBatch(ctx context.Context, batch *dataprocpb.Batch) (map[string]any, error) {
req := &dataprocpb.CreateBatchRequest{
Parent: fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation()),
Batch: batch,
}
client := s.GetBatchControllerClient()
op, err := client.CreateBatch(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to create batch: %w", err)
}
meta, err := op.Metadata()
if err != nil {
return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
}
projectID, location, batchID, err := ExtractBatchDetails(meta.GetBatch())
if err != nil {
return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
}
consoleUrl := BatchConsoleURL(projectID, location, batchID)
logsUrl := BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})
wrappedResult := map[string]any{
"opMetadata": meta,
"consoleUrl": consoleUrl,
"logsUrl": logsUrl,
}
return wrappedResult, nil
}
// ListBatchesResponse is the response from the list batches API.
type ListBatchesResponse struct {
Batches []Batch `json:"batches"`
NextPageToken string `json:"nextPageToken"`
}
// Batch represents a single batch job.
type Batch struct {
Name string `json:"name"`
UUID string `json:"uuid"`
State string `json:"state"`
Creator string `json:"creator"`
CreateTime string `json:"createTime"`
Operation string `json:"operation"`
ConsoleURL string `json:"consoleUrl"`
LogsURL string `json:"logsUrl"`
}
func (s *Source) ListBatches(ctx context.Context, ps *int, pt, filter string) (any, error) {
client := s.GetBatchControllerClient()
parent := fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation())
req := &dataprocpb.ListBatchesRequest{
Parent: parent,
OrderBy: "create_time desc",
}
if ps != nil {
req.PageSize = int32(*ps)
}
if pt != "" {
req.PageToken = pt
}
if filter != "" {
req.Filter = filter
}
it := client.ListBatches(ctx, req)
pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)
var batchPbs []*dataprocpb.Batch
nextPageToken, err := pager.NextPage(&batchPbs)
if err != nil {
return nil, fmt.Errorf("failed to list batches: %w", err)
}
batches, err := ToBatches(batchPbs)
if err != nil {
return nil, err
}
return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
}
// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
batches := make([]Batch, 0, len(batchPbs))
for _, batchPb := range batchPbs {
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating console url: %v", err)
}
logsUrl, err := BatchLogsURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating logs url: %v", err)
}
batch := Batch{
Name: batchPb.Name,
UUID: batchPb.Uuid,
State: batchPb.State.Enum().String(),
Creator: batchPb.Creator,
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
Operation: batchPb.Operation,
ConsoleURL: consoleUrl,
LogsURL: logsUrl,
}
batches = append(batches, batch)
}
return batches, nil
}
func (s *Source) GetBatch(ctx context.Context, name string) (map[string]any, error) {
client := s.GetBatchControllerClient()
req := &dataprocpb.GetBatchRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", s.GetProject(), s.GetLocation(), name),
}
batchPb, err := client.GetBatch(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to get batch: %w", err)
}
jsonBytes, err := protojson.Marshal(batchPb)
if err != nil {
return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
}
var result map[string]any
if err := json.Unmarshal(jsonBytes, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
}
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating console url: %v", err)
}
logsUrl, err := BatchLogsURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating logs url: %v", err)
}
wrappedResult := map[string]any{
"consoleUrl": consoleUrl,
"logsUrl": logsUrl,
"batch": result,
}
return wrappedResult, nil
}

View File

@@ -42,7 +42,7 @@ func TestParseFromYamlServerlessSpark(t *testing.T) {
want: server.SourceConfigs{
"my-instance": serverlessspark.Config{
Name: "my-instance",
Type: serverlessspark.SourceType,
Kind: serverlessspark.SourceKind,
Project: "my-project",
Location: "my-location",
},

View File

@@ -29,15 +29,15 @@ import (
"go.opentelemetry.io/otel/trace"
)
// SourceType for SingleStore source
const SourceType string = "singlestore"
// SourceKind for SingleStore source
const SourceKind string = "singlestore"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -52,7 +52,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
// Config holds the configuration parameters for connecting to a SingleStore database.
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -61,9 +61,9 @@ type Config struct {
QueryTimeout string `yaml:"queryTimeout"`
}
// SourceConfigType returns the kind of the source configuration.
func (r Config) SourceConfigType() string {
return SourceType
// SourceConfigKind returns the kind of the source configuration.
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize sets up the SingleStore connection pool and returns a Source.
@@ -93,9 +93,9 @@ type Source struct {
Pool *sql.DB
}
// SourceType returns the kind of the source configuration.
func (s *Source) SourceType() string {
return SourceType
// SourceKind returns the kind of the source configuration.
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -162,7 +162,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initSingleStoreConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Configure the driver to connect to the database

View File

@@ -45,7 +45,7 @@ func TestParseFromYaml(t *testing.T) {
want: server.SourceConfigs{
"my-s2-instance": singlestore.Config{
Name: "my-s2-instance",
Type: singlestore.SourceType,
Kind: singlestore.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -70,7 +70,7 @@ func TestParseFromYaml(t *testing.T) {
want: server.SourceConfigs{
"my-s2-instance": singlestore.Config{
Name: "my-s2-instance",
Type: singlestore.SourceType,
Kind: singlestore.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",

View File

@@ -25,14 +25,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "snowflake"
const SourceKind string = "snowflake"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -46,7 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Account string `yaml:"account" validate:"required"`
User string `yaml:"user" validate:"required"`
Password string `yaml:"password" validate:"required"`
@@ -56,8 +56,8 @@ type Config struct {
Role string `yaml:"role"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -85,8 +85,8 @@ type Source struct {
DB *sqlx.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -137,7 +137,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initSnowflakeConnection(ctx context.Context, tracer trace.Tracer, name, account, user, password, database, schema, warehouse, role string) (*sqlx.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Set defaults for optional parameters

View File

@@ -45,7 +45,7 @@ func TestParseFromYamlSnowflake(t *testing.T) {
want: server.SourceConfigs{
"my-snowflake-instance": snowflake.Config{
Name: "my-snowflake-instance",
Type: snowflake.SourceType,
Kind: snowflake.SourceKind,
Account: "my-account",
User: "my_user",
Password: "my_pass",

View File

@@ -31,46 +31,46 @@ var sourceRegistry = make(map[string]SourceConfigFactory)
// Register registers a new source kind with its factory.
// It returns false if the kind is already registered.
func Register(sourceType string, factory SourceConfigFactory) bool {
if _, exists := sourceRegistry[sourceType]; exists {
func Register(kind string, factory SourceConfigFactory) bool {
if _, exists := sourceRegistry[kind]; exists {
// Source with this kind already exists, do not overwrite.
return false
}
sourceRegistry[sourceType] = factory
sourceRegistry[kind] = factory
return true
}
// DecodeConfig decodes a source configuration using the registered factory for the given kind.
func DecodeConfig(ctx context.Context, sourceType string, name string, decoder *yaml.Decoder) (SourceConfig, error) {
factory, found := sourceRegistry[sourceType]
func DecodeConfig(ctx context.Context, kind string, name string, decoder *yaml.Decoder) (SourceConfig, error) {
factory, found := sourceRegistry[kind]
if !found {
return nil, fmt.Errorf("unknown source kind: %q", sourceType)
return nil, fmt.Errorf("unknown source kind: %q", kind)
}
sourceConfig, err := factory(ctx, name, decoder)
if err != nil {
return nil, fmt.Errorf("unable to parse source %q as %q: %w", name, sourceType, err)
return nil, fmt.Errorf("unable to parse source %q as %q: %w", name, kind, err)
}
return sourceConfig, err
}
// SourceConfig is the interface for configuring a source.
type SourceConfig interface {
SourceConfigType() string
SourceConfigKind() string
Initialize(ctx context.Context, tracer trace.Tracer) (Source, error)
}
// Source is the interface for the source itself.
type Source interface {
SourceType() string
SourceKind() string
ToConfig() SourceConfig
}
// InitConnectionSpan adds a span for database pool connection initialization
func InitConnectionSpan(ctx context.Context, tracer trace.Tracer, sourceType, sourceName string) (context.Context, trace.Span) {
func InitConnectionSpan(ctx context.Context, tracer trace.Tracer, sourceKind, sourceName string) (context.Context, trace.Span) {
ctx, span := tracer.Start(
ctx,
"toolbox/server/source/connect",
trace.WithAttributes(attribute.String("source_type", sourceType)),
trace.WithAttributes(attribute.String("source_kind", sourceKind)),
trace.WithAttributes(attribute.String("source_name", sourceName)),
)
return ctx, span

View File

@@ -28,14 +28,14 @@ import (
"google.golang.org/api/iterator"
)
const SourceType string = "spanner"
const SourceKind string = "spanner"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -49,15 +49,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"kind" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
Dialect sources.Dialect `yaml:"dialect" validate:"required"`
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -80,8 +80,8 @@ type Source struct {
Client *spanner.Client
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -171,7 +171,7 @@ func (s *Source) RunSQL(ctx context.Context, readOnly bool, statement string, pa
func initSpannerClient(ctx context.Context, tracer trace.Tracer, name, project, instance, dbname string) (*spanner.Client, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Configure the connection to the database

Some files were not shown because too many files have changed in this diff Show More