mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-22 05:48:08 -05:00
Compare commits
7 Commits
config-yam
...
docs/toolb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4a7db50333 | ||
|
|
62ceb4bb20 | ||
|
|
bddd439e51 | ||
|
|
4e0d7413d3 | ||
|
|
e4a51ad198 | ||
|
|
90356de685 | ||
|
|
b21734f382 |
@@ -59,13 +59,6 @@ You can manually trigger the bot by commenting on your Pull Request:
|
||||
* `/gemini summary`: Posts a summary of the changes in the pull request.
|
||||
* `/gemini help`: Overview of the available commands
|
||||
|
||||
## Guidelines for Pull Requests
|
||||
|
||||
1. Please keep your PR small for more thorough review and easier updates. In case of regression, it also allows us to roll back a single feature instead of multiple ones.
|
||||
1. For non-trivial changes, consider opening an issue and discussing it with the code owners first.
|
||||
1. Provide a good PR description as a record of what change is being made and why it was made. Link to a GitHub issue if it exists.
|
||||
1. Make sure your code is thoroughly tested with unit tests and integration tests. Remember to clean up the test instances properly in your code to avoid memory leaks.
|
||||
|
||||
## Adding a New Database Source or Tool
|
||||
|
||||
Please create an
|
||||
@@ -92,11 +85,11 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
|
||||
`newdb.go`. Create a `Config` struct to include all the necessary parameters
|
||||
for connecting to the database (e.g., host, port, username, password, database
|
||||
name) and a `Source` struct to store necessary parameters for tools (e.g.,
|
||||
Name, Type, connection object, additional config).
|
||||
Name, Kind, connection object, additional config).
|
||||
* **Implement the
|
||||
[`SourceConfig`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/sources/sources.go#L57)
|
||||
interface**. This interface requires two methods:
|
||||
* `SourceConfigType() string`: Returns a unique string identifier for your
|
||||
* `SourceConfigKind() string`: Returns a unique string identifier for your
|
||||
data source (e.g., `"newdb"`).
|
||||
* `Initialize(ctx context.Context, tracer trace.Tracer) (Source, error)`:
|
||||
Creates a new instance of your data source and establishes a connection to
|
||||
@@ -104,7 +97,7 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
|
||||
* **Implement the
|
||||
[`Source`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/sources/sources.go#L63)
|
||||
interface**. This interface requires one method:
|
||||
* `SourceType() string`: Returns the same string identifier as `SourceConfigType()`.
|
||||
* `SourceKind() string`: Returns the same string identifier as `SourceConfigKind()`.
|
||||
* **Implement `init()`** to register the new Source.
|
||||
* **Implement Unit Tests** in a file named `newdb_test.go`.
|
||||
|
||||
@@ -117,8 +110,6 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
|
||||
We recommend looking at an [example tool
|
||||
implementation](https://github.com/googleapis/genai-toolbox/tree/main/internal/tools/postgres/postgressql).
|
||||
|
||||
Remember to keep your PRs small. For example, if you are contributing a new Source, only include one or two core Tools within the same PR, the rest of the Tools can come in subsequent PRs.
|
||||
|
||||
* **Create a new directory** under `internal/tools` for your tool type (e.g., `internal/tools/newdb/newdbtool`).
|
||||
* **Define a configuration struct** for your tool in a file named `newdbtool.go`.
|
||||
Create a `Config` struct and a `Tool` struct to store necessary parameters for
|
||||
@@ -126,7 +117,7 @@ tools.
|
||||
* **Implement the
|
||||
[`ToolConfig`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/tools/tools.go#L61)
|
||||
interface**. This interface requires one method:
|
||||
* `ToolConfigType() string`: Returns a unique string identifier for your tool
|
||||
* `ToolConfigKind() string`: Returns a unique string identifier for your tool
|
||||
(e.g., `"newdb-tool"`).
|
||||
* `Initialize(sources map[string]Source) (Tool, error)`: Creates a new
|
||||
instance of your tool and validates that it can connect to the specified
|
||||
@@ -172,8 +163,6 @@ tools.
|
||||
parameters][temp-param-doc]. Only run this test if template
|
||||
parameters apply to your tool.
|
||||
|
||||
* **Add additional tests** for the tools that are not covered by the predefined tests. Every tool must be tested!
|
||||
|
||||
* **Add the new database to the integration test workflow** in
|
||||
[integration.cloudbuild.yaml](.ci/integration.cloudbuild.yaml).
|
||||
|
||||
@@ -243,7 +232,7 @@ resources.
|
||||
| style | Update src code, with only formatting and whitespace updates (e.g. code formatter or linter changes). |
|
||||
|
||||
Pull requests should always add scope whenever possible. The scope is
|
||||
formatted as `<scope-resource>/<scope-type>` (e.g., `sources/postgres`, or
|
||||
formatted as `<scope-type>/<scope-kind>` (e.g., `sources/postgres`, or
|
||||
`tools/mssql-sql`).
|
||||
|
||||
Ideally, **each PR covers only one scope**, if this is
|
||||
@@ -255,4 +244,4 @@ resources.
|
||||
* **PR Description:** PR description should **always** be included. It should
|
||||
include a concise description of the changes, it's impact, along with a
|
||||
summary of the solution. If the PR is related to a specific issue, the issue
|
||||
number should be mentioned in the PR description (e.g. `Fixes #1`).
|
||||
number should be mentioned in the PR description (e.g. `Fixes #1`).
|
||||
@@ -954,7 +954,7 @@ For more details on configuring different types of sources, see the
|
||||
### Tools
|
||||
|
||||
The `tools` section of a `tools.yaml` define the actions an agent can take: what
|
||||
type of tool it is, which source(s) it affects, what parameters it uses, etc.
|
||||
kind of tool it is, which source(s) it affects, what parameters it uses, etc.
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
|
||||
136
cmd/root.go
136
cmd/root.go
@@ -15,7 +15,6 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
@@ -93,7 +92,6 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances"
|
||||
@@ -395,6 +393,7 @@ func NewCommand(opts ...Option) *Command {
|
||||
|
||||
type ToolsFile struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
|
||||
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
|
||||
EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
@@ -425,106 +424,6 @@ func parseEnv(input string) (string, error) {
|
||||
return output, err
|
||||
}
|
||||
|
||||
func convertToolsFile(ctx context.Context, raw []byte) ([]byte, error) {
|
||||
var input yaml.MapSlice
|
||||
decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap())
|
||||
if err := decoder.Decode(&input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert raw MapSlice to a helper map for quick lookup
|
||||
// while keeping the values as MapSlices to preserve internal order
|
||||
resourceOrder := []string{}
|
||||
lookup := make(map[string]yaml.MapSlice)
|
||||
for _, item := range input {
|
||||
key, ok := item.Key.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected non-string key in input: %v", item.Key)
|
||||
}
|
||||
if slice, ok := item.Value.(yaml.MapSlice); ok {
|
||||
// convert authSources to authServices
|
||||
if key == "authSources" {
|
||||
key = "authServices"
|
||||
}
|
||||
// works even if lookup[key] is nil
|
||||
lookup[key] = append(lookup[key], slice...)
|
||||
// preserving the resource's order of original toolsFile
|
||||
if !slices.Contains(resourceOrder, key) {
|
||||
resourceOrder = append(resourceOrder, key)
|
||||
}
|
||||
} else {
|
||||
// toolsfile is already v2
|
||||
if key == "kind" {
|
||||
return raw, nil
|
||||
}
|
||||
return nil, fmt.Errorf("'%s' is not a map", key)
|
||||
}
|
||||
}
|
||||
// convert to tools file v2
|
||||
var buf bytes.Buffer
|
||||
encoder := yaml.NewEncoder(&buf)
|
||||
for _, kind := range resourceOrder {
|
||||
data, exists := lookup[kind]
|
||||
if !exists {
|
||||
// if this is skipped for all keys, the tools file is in v2
|
||||
continue
|
||||
}
|
||||
// Transform each entry
|
||||
for _, entry := range data {
|
||||
entryName, ok := entry.Key.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected non-string key for entry in '%s': %v", kind, entry.Key)
|
||||
}
|
||||
entryBody := ProcessValue(entry.Value, kind == "toolsets")
|
||||
|
||||
transformed := yaml.MapSlice{
|
||||
{Key: "kind", Value: kind},
|
||||
{Key: "name", Value: entryName},
|
||||
}
|
||||
|
||||
// Merge the transformed body into our result
|
||||
if bodySlice, ok := entryBody.(yaml.MapSlice); ok {
|
||||
transformed = append(transformed, bodySlice...)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unable to convert entryBody to MapSlice")
|
||||
}
|
||||
|
||||
if err := encoder.Encode(transformed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type'
|
||||
func ProcessValue(v any, isToolset bool) any {
|
||||
switch val := v.(type) {
|
||||
case yaml.MapSlice:
|
||||
for i := range val {
|
||||
// Perform renaming
|
||||
if val[i].Key == "kind" {
|
||||
val[i].Key = "type"
|
||||
}
|
||||
// Recursive call for nested values (e.g., nested objects or lists)
|
||||
val[i].Value = ProcessValue(val[i].Value, false)
|
||||
}
|
||||
return val
|
||||
case []any:
|
||||
// Process lists: If it's a toolset top-level list, wrap it.
|
||||
if isToolset {
|
||||
return yaml.MapSlice{{Key: "tools", Value: val}}
|
||||
}
|
||||
// Otherwise, recurse into list items (to catch nested objects)
|
||||
for i := range val {
|
||||
val[i] = ProcessValue(val[i], false)
|
||||
}
|
||||
return val
|
||||
default:
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
// parseToolsFile parses the provided yaml into appropriate configs.
|
||||
func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
|
||||
var toolsFile ToolsFile
|
||||
@@ -535,13 +434,8 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
|
||||
}
|
||||
raw = []byte(output)
|
||||
|
||||
raw, err = convertToolsFile(ctx, raw)
|
||||
if err != nil {
|
||||
return toolsFile, fmt.Errorf("error converting tools file: %s", err)
|
||||
}
|
||||
|
||||
// Parse contents
|
||||
toolsFile.Sources, toolsFile.AuthServices, toolsFile.EmbeddingModels, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts, err = server.UnmarshalResourceConfig(ctx, raw)
|
||||
err = yaml.UnmarshalContext(ctx, raw, &toolsFile, yaml.Strict())
|
||||
if err != nil {
|
||||
return toolsFile, err
|
||||
}
|
||||
@@ -573,6 +467,18 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge authSources (deprecated, but still support)
|
||||
for name, authSource := range file.AuthSources {
|
||||
if _, exists := merged.AuthSources[name]; exists {
|
||||
conflicts = append(conflicts, fmt.Sprintf("authSource '%s' (file #%d)", name, fileIndex+1))
|
||||
} else {
|
||||
if merged.AuthSources == nil {
|
||||
merged.AuthSources = make(server.AuthServiceConfigs)
|
||||
}
|
||||
merged.AuthSources[name] = authSource
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge authServices
|
||||
for name, authService := range file.AuthServices {
|
||||
if _, exists := merged.AuthServices[name]; exists {
|
||||
@@ -1048,6 +954,20 @@ func run(cmd *Command) error {
|
||||
cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets
|
||||
cmd.cfg.PromptConfigs = finalToolsFile.Prompts
|
||||
|
||||
authSourceConfigs := finalToolsFile.AuthSources
|
||||
if authSourceConfigs != nil {
|
||||
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
|
||||
|
||||
for k, v := range authSourceConfigs {
|
||||
if _, exists := cmd.cfg.AuthServiceConfigs[k]; exists {
|
||||
errMsg := fmt.Errorf("resource conflict detected: authSource '%s' has the same name as an existing authService. Please rename your authSource", k)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
cmd.cfg.AuthServiceConfigs[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err)
|
||||
|
||||
467
cmd/root_test.go
467
cmd/root_test.go
@@ -23,14 +23,12 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
||||
@@ -496,309 +494,6 @@ func TestDefaultLogLevel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToolsFile(t *testing.T) {
|
||||
ctx, cancelCtx := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancelCtx()
|
||||
pr, pw := io.Pipe()
|
||||
defer pw.Close()
|
||||
defer pr.Close()
|
||||
|
||||
logger, err := log.NewStdLogger(pw, pw, "DEBUG")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup logger %s", err)
|
||||
}
|
||||
ctx = util.WithLogger(ctx, logger)
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want string
|
||||
isErr bool
|
||||
errStr string
|
||||
}{
|
||||
{
|
||||
desc: "basic convert",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
authServices:
|
||||
my-google-auth:
|
||||
kind: google
|
||||
clientId: testing-id
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
toolsets:
|
||||
example_toolset:
|
||||
- example_tool
|
||||
prompts:
|
||||
code_review:
|
||||
description: ask llm to analyze code quality
|
||||
messages:
|
||||
- content: "please review the following code for quality: {{.code}}"
|
||||
arguments:
|
||||
- name: code
|
||||
description: the code to review
|
||||
embeddingModels:
|
||||
gemini-model:
|
||||
kind: gemini
|
||||
model: gemini-embedding-001
|
||||
apiKey: some-key
|
||||
dimension: 768`,
|
||||
want: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: authServices
|
||||
name: my-google-auth
|
||||
type: google
|
||||
clientId: testing-id
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool
|
||||
---
|
||||
kind: prompts
|
||||
name: code_review
|
||||
description: ask llm to analyze code quality
|
||||
messages:
|
||||
- content: "please review the following code for quality: {{.code}}"
|
||||
arguments:
|
||||
- name: code
|
||||
description: the code to review
|
||||
---
|
||||
kind: embeddingModels
|
||||
name: gemini-model
|
||||
type: gemini
|
||||
model: gemini-embedding-001
|
||||
apiKey: some-key
|
||||
dimension: 768`,
|
||||
},
|
||||
{
|
||||
desc: "preserve resource order with grouping",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
authServices:
|
||||
my-google-auth:
|
||||
kind: google
|
||||
clientId: testing-id
|
||||
toolsets:
|
||||
example_toolset:
|
||||
- example_tool
|
||||
authSources:
|
||||
my-google-auth:
|
||||
kind: google
|
||||
clientId: testing-id`,
|
||||
want: `
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: authServices
|
||||
name: my-google-auth
|
||||
type: google
|
||||
clientId: testing-id
|
||||
---
|
||||
kind: authServices
|
||||
name: my-google-auth
|
||||
type: google
|
||||
clientId: testing-id
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
},
|
||||
{
|
||||
desc: "no convertion needed",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
want: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
},
|
||||
{
|
||||
desc: "invalid source",
|
||||
in: `sources: invalid`,
|
||||
isErr: true,
|
||||
errStr: "'sources' is not a map",
|
||||
},
|
||||
{
|
||||
desc: "invalid toolset",
|
||||
in: `toolsets: invalid`,
|
||||
isErr: true,
|
||||
errStr: "'toolsets' is not a map",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
output, err := convertToolsFile(ctx, []byte(tc.in))
|
||||
if tc.isErr {
|
||||
if err == nil {
|
||||
t.Fatalf("missing error: %s", tc.errStr)
|
||||
}
|
||||
if err.Error() != tc.errStr {
|
||||
t.Fatalf("invalid error string: got %s, want %s", err, tc.errStr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
var docs1, docs2 []yaml.MapSlice
|
||||
if docs1, err = decodeToMapSlice(string(output)); err != nil {
|
||||
t.Fatalf("error decoding output: %s", err)
|
||||
}
|
||||
if docs2, err = decodeToMapSlice(tc.want); err != nil {
|
||||
t.Fatalf("Error decoding want: %s", err)
|
||||
}
|
||||
if !reflect.DeepEqual(docs1, docs2) {
|
||||
t.Fatalf("incorrect output: got %s, want %s", string(output), tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func decodeToMapSlice(data string) ([]yaml.MapSlice, error) {
|
||||
// ensures that the order is correct
|
||||
var docs []yaml.MapSlice
|
||||
decoder := yaml.NewDecoder(strings.NewReader(data))
|
||||
for {
|
||||
var doc yaml.MapSlice
|
||||
err := decoder.Decode(&doc)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
func TestParseToolFile(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
@@ -810,7 +505,7 @@ func TestParseToolFile(t *testing.T) {
|
||||
wantToolsFile ToolsFile
|
||||
}{
|
||||
{
|
||||
description: "basic example tools file v1",
|
||||
description: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
@@ -840,7 +535,7 @@ func TestParseToolFile(t *testing.T) {
|
||||
Sources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: cloudsqlpgsrc.SourceType,
|
||||
Kind: cloudsqlpgsrc.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -853,7 +548,7 @@ func TestParseToolFile(t *testing.T) {
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Type: "postgres-sql",
|
||||
Kind: "postgres-sql",
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
@@ -873,121 +568,7 @@ func TestParseToolFile(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "basic example tools file v2",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: authServices
|
||||
name: my-google-auth
|
||||
type: google
|
||||
clientId: testing-id
|
||||
---
|
||||
kind: embeddingModels
|
||||
name: gemini-model
|
||||
type: gemini
|
||||
model: gemini-embedding-001
|
||||
apiKey: some-key
|
||||
dimension: 768
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool
|
||||
---
|
||||
kind: prompts
|
||||
name: code_review
|
||||
description: ask llm to analyze code quality
|
||||
messages:
|
||||
- content: "please review the following code for quality: {{.code}}"
|
||||
arguments:
|
||||
- name: code
|
||||
description: the code to review
|
||||
`,
|
||||
wantToolsFile: ToolsFile{
|
||||
Sources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: cloudsqlpgsrc.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
IPType: "public",
|
||||
Database: "my_db",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
},
|
||||
},
|
||||
AuthServices: server.AuthServiceConfigs{
|
||||
"my-google-auth": google.Config{
|
||||
Name: "my-google-auth",
|
||||
Type: google.AuthServiceType,
|
||||
ClientID: "testing-id",
|
||||
},
|
||||
},
|
||||
EmbeddingModels: server.EmbeddingModelConfigs{
|
||||
"gemini-model": gemini.Config{
|
||||
Name: "gemini-model",
|
||||
Type: gemini.EmbeddingModelType,
|
||||
Model: "gemini-embedding-001",
|
||||
ApiKey: "some-key",
|
||||
Dimension: 768,
|
||||
},
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Type: "postgres-sql",
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Parameters: []parameters.Parameter{
|
||||
parameters.NewStringParameter("country", "some description"),
|
||||
},
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
Toolsets: server.ToolsetConfigs{
|
||||
"example_toolset": tools.ToolsetConfig{
|
||||
Name: "example_toolset",
|
||||
ToolNames: []string{"example_tool"},
|
||||
},
|
||||
},
|
||||
Prompts: server.PromptConfigs{
|
||||
"code_review": custom.Config{
|
||||
Name: "code_review",
|
||||
Description: "ask llm to analyze code quality",
|
||||
Arguments: prompts.Arguments{
|
||||
{Parameter: parameters.NewStringParameter("code", "the code to review")},
|
||||
},
|
||||
Messages: []prompts.Message{
|
||||
{Role: "user", Content: "please review the following code for quality: {{.code}}"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "only prompts",
|
||||
description: "with prompts example",
|
||||
in: `
|
||||
prompts:
|
||||
my-prompt:
|
||||
@@ -1108,7 +689,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
Sources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: cloudsqlpgsrc.SourceType,
|
||||
Kind: cloudsqlpgsrc.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -1121,19 +702,19 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
AuthServices: server.AuthServiceConfigs{
|
||||
"my-google-service": google.Config{
|
||||
Name: "my-google-service",
|
||||
Type: google.AuthServiceType,
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "my-client-id",
|
||||
},
|
||||
"other-google-service": google.Config{
|
||||
Name: "other-google-service",
|
||||
Type: google.AuthServiceType,
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "other-client-id",
|
||||
},
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Type: "postgres-sql",
|
||||
Kind: "postgres-sql",
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
@@ -1208,7 +789,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
Sources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: cloudsqlpgsrc.SourceType,
|
||||
Kind: cloudsqlpgsrc.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -1218,22 +799,22 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
Password: "my_pass",
|
||||
},
|
||||
},
|
||||
AuthServices: server.AuthServiceConfigs{
|
||||
AuthSources: server.AuthServiceConfigs{
|
||||
"my-google-service": google.Config{
|
||||
Name: "my-google-service",
|
||||
Type: google.AuthServiceType,
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "my-client-id",
|
||||
},
|
||||
"other-google-service": google.Config{
|
||||
Name: "other-google-service",
|
||||
Type: google.AuthServiceType,
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "other-client-id",
|
||||
},
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Type: "postgres-sql",
|
||||
Kind: "postgres-sql",
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
@@ -1310,7 +891,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
Sources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: cloudsqlpgsrc.SourceType,
|
||||
Kind: cloudsqlpgsrc.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -1323,19 +904,19 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
AuthServices: server.AuthServiceConfigs{
|
||||
"my-google-service": google.Config{
|
||||
Name: "my-google-service",
|
||||
Type: google.AuthServiceType,
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "my-client-id",
|
||||
},
|
||||
"other-google-service": google.Config{
|
||||
Name: "other-google-service",
|
||||
Type: google.AuthServiceType,
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "other-client-id",
|
||||
},
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Type: "postgres-sql",
|
||||
Kind: "postgres-sql",
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
@@ -1481,7 +1062,7 @@ func TestEnvVarReplacement(t *testing.T) {
|
||||
Sources: server.SourceConfigs{
|
||||
"my-http-instance": httpsrc.Config{
|
||||
Name: "my-http-instance",
|
||||
Type: httpsrc.SourceType,
|
||||
Kind: httpsrc.SourceKind,
|
||||
BaseURL: "http://test_server/",
|
||||
Timeout: "10s",
|
||||
DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"},
|
||||
@@ -1491,19 +1072,19 @@ func TestEnvVarReplacement(t *testing.T) {
|
||||
AuthServices: server.AuthServiceConfigs{
|
||||
"my-google-service": google.Config{
|
||||
Name: "my-google-service",
|
||||
Type: google.AuthServiceType,
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "ACTUAL_CLIENT_ID",
|
||||
},
|
||||
"other-google-service": google.Config{
|
||||
Name: "other-google-service",
|
||||
Type: google.AuthServiceType,
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "ACTUAL_CLIENT_ID_2",
|
||||
},
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": http.Config{
|
||||
Name: "example_tool",
|
||||
Type: "http",
|
||||
Kind: "http",
|
||||
Source: "my-instance",
|
||||
Method: "GET",
|
||||
Path: "search?name=alice&pet=cat",
|
||||
@@ -1912,7 +1493,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"cloud_sql_postgres_admin_tools": tools.ToolsetConfig{
|
||||
Name: "cloud_sql_postgres_admin_tools",
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup"},
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1922,7 +1503,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"cloud_sql_mysql_admin_tools": tools.ToolsetConfig{
|
||||
Name: "cloud_sql_mysql_admin_tools",
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1932,7 +1513,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"cloud_sql_mssql_admin_tools": tools.ToolsetConfig{
|
||||
Name: "cloud_sql_mssql_admin_tools",
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -509,7 +509,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install toolbox-core --quiet\n",
|
||||
"! pip install toolbox-adk --quiet\n",
|
||||
"! pip install google-adk --quiet"
|
||||
]
|
||||
},
|
||||
@@ -525,14 +525,18 @@
|
||||
"from google.adk.runners import Runner\n",
|
||||
"from google.adk.sessions import InMemorySessionService\n",
|
||||
"from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService\n",
|
||||
"from google.adk.tools.toolbox_toolset import ToolboxToolset\n",
|
||||
"from google.genai import types\n",
|
||||
"from toolbox_core import ToolboxSyncClient\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"# TODO(developer): replace this with your Google API key\n",
|
||||
"os.environ['GOOGLE_API_KEY'] = \"<GOOGLE_API_KEY>\"\n",
|
||||
"\n",
|
||||
"toolbox_client = ToolboxSyncClient(\"http://127.0.0.1:5000\")\n",
|
||||
"# Configure toolset\n",
|
||||
"toolset = ToolboxToolset(\n",
|
||||
" server_url=\"http://127.0.0.1:5000\",\n",
|
||||
" toolset_name=\"my-toolset\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"prompt = \"\"\"\n",
|
||||
" You're a helpful hotel assistant. You handle hotel searching, booking and\n",
|
||||
@@ -549,7 +553,7 @@
|
||||
" name='hotel_agent',\n",
|
||||
" description='A helpful AI assistant.',\n",
|
||||
" instruction=prompt,\n",
|
||||
" tools=toolbox_client.load_toolset(\"my-toolset\"),\n",
|
||||
" tools=[toolset],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"session_service = InMemorySessionService()\n",
|
||||
|
||||
@@ -52,7 +52,7 @@ runtime](https://research.google.com/colaboratory/local-runtimes.html).
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="ADK" lang="bash" >}}
|
||||
|
||||
pip install toolbox-core
|
||||
pip install toolbox-adk
|
||||
{{< /tab >}}
|
||||
{{< tab header="Langchain" lang="bash" >}}
|
||||
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
from google.adk import Agent
|
||||
from google.adk.apps import App
|
||||
from toolbox_core import ToolboxSyncClient
|
||||
from google.adk.tools.toolbox_toolset import ToolboxToolset
|
||||
|
||||
# TODO(developer): update the TOOLBOX_URL to your toolbox endpoint
|
||||
client = ToolboxSyncClient("http://127.0.0.1:5000")
|
||||
toolset = ToolboxToolset(
|
||||
server_url="http://127.0.0.1:5000",
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model='gemini-2.5-flash',
|
||||
instruction="You are a helpful AI assistant designed to provide accurate and useful information.",
|
||||
tools=client.load_toolset(),
|
||||
tools=[toolset],
|
||||
)
|
||||
|
||||
app = App(root_agent=root_agent, name="my_agent")
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
google-adk==1.21.0
|
||||
toolbox-core==0.5.4
|
||||
toolbox-adk>=0.1.0
|
||||
pytest==9.0.2
|
||||
@@ -48,7 +48,6 @@ instance, database and users:
|
||||
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* `roles/cloudsql.admin`: Provides full control over all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
* `create_instance`
|
||||
@@ -300,7 +299,6 @@ instances and interacting with your database:
|
||||
* **create_user**: Creates a new user in a Cloud SQL instance.
|
||||
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
|
||||
* **clone_instance**: Creates a clone of an existing Cloud SQL for SQL Server instance.
|
||||
* **create_backup**: Creates a backup on a Cloud SQL instance.
|
||||
|
||||
{{< notice note >}}
|
||||
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
|
||||
|
||||
@@ -48,7 +48,6 @@ database and users:
|
||||
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* `roles/cloudsql.admin`: Provides full control over all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
* `create_instance`
|
||||
@@ -300,7 +299,6 @@ instances and interacting with your database:
|
||||
* **create_user**: Creates a new user in a Cloud SQL instance.
|
||||
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
|
||||
* **clone_instance**: Creates a clone of an existing Cloud SQL for MySQL instance.
|
||||
* **create_backup**: Creates a backup on a Cloud SQL instance.
|
||||
|
||||
{{< notice note >}}
|
||||
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
|
||||
|
||||
@@ -48,7 +48,6 @@ instance, database and users:
|
||||
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* `roles/cloudsql.admin`: Provides full control over all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
* `create_instance`
|
||||
@@ -300,7 +299,6 @@ instances and interacting with your database:
|
||||
* **create_user**: Creates a new user in a Cloud SQL instance.
|
||||
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
|
||||
* **clone_instance**: Creates a clone of an existing Cloud SQL for PostgreSQL instance.
|
||||
* **create_backup**: Creates a backup on a Cloud SQL instance.
|
||||
|
||||
{{< notice note >}}
|
||||
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
|
||||
|
||||
@@ -49,7 +49,7 @@ with the necessary configuration for deployment to Vertex AI Agent Engine.
|
||||
4. Add `toolbox-core` as a dependency to the new project:
|
||||
|
||||
```bash
|
||||
uv add toolbox-core
|
||||
uv add toolbox-adk
|
||||
```
|
||||
|
||||
## Step 3: Configure Google Cloud Authentication
|
||||
@@ -95,22 +95,23 @@ authentication token.
|
||||
```python
|
||||
from google.adk import Agent
|
||||
from google.adk.apps import App
|
||||
from toolbox_core import ToolboxSyncClient, auth_methods
|
||||
from google.adk.tools.toolbox_toolset import ToolboxToolset
|
||||
from toolbox_adk import CredentialStrategy
|
||||
|
||||
# TODO(developer): Replace with your Toolbox Cloud Run Service URL
|
||||
TOOLBOX_URL = "https://your-toolbox-service-xyz.a.run.app"
|
||||
|
||||
# Initialize the client with the Cloud Run URL and Auth headers
|
||||
client = ToolboxSyncClient(
|
||||
TOOLBOX_URL,
|
||||
client_headers={"Authorization": auth_methods.get_google_id_token(TOOLBOX_URL)}
|
||||
# Initialize the toolset with Workload Identity (generates ID token for the URL)
|
||||
toolset = ToolboxToolset(
|
||||
server_url=TOOLBOX_URL,
|
||||
credentials=CredentialStrategy.workload_identity(target_audience=TOOLBOX_URL)
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model='gemini-2.5-flash',
|
||||
instruction="You are a helpful AI assistant designed to provide accurate and useful information.",
|
||||
tools=client.load_toolset(),
|
||||
tools=[toolset],
|
||||
)
|
||||
|
||||
app = App(root_agent=root_agent, name="my_agent")
|
||||
|
||||
@@ -187,7 +187,6 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
|
||||
all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
@@ -204,7 +203,6 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
* `create_user`: Creates a new user in a Cloud SQL instance.
|
||||
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
|
||||
* `clone_instance`: Creates a clone for an existing Cloud SQL for MySQL instance.
|
||||
* `create_backup`: Creates a backup on a Cloud SQL instance.
|
||||
|
||||
## Cloud SQL for PostgreSQL
|
||||
|
||||
@@ -277,7 +275,6 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
|
||||
all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
@@ -293,7 +290,6 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
* `create_user`: Creates a new user in a Cloud SQL instance.
|
||||
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
|
||||
* `clone_instance`: Creates a clone for an existing Cloud SQL for PostgreSQL instance.
|
||||
* `create_backup`: Creates a backup on a Cloud SQL instance.
|
||||
|
||||
## Cloud SQL for SQL Server
|
||||
|
||||
@@ -340,7 +336,6 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
|
||||
all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
@@ -356,7 +351,6 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
* `create_user`: Creates a new user in a Cloud SQL instance.
|
||||
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
|
||||
* `clone_instance`: Creates a clone for an existing Cloud SQL for SQL Server instance.
|
||||
* `create_backup`: Creates a backup on a Cloud SQL instance.
|
||||
|
||||
## Dataplex
|
||||
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
---
|
||||
title: cloud-sql-create-backup
|
||||
type: docs
|
||||
weight: 10
|
||||
description: "Creates a backup on a Cloud SQL instance."
|
||||
---
|
||||
|
||||
The `cloud-sql-create-backup` tool creates an on-demand backup on a Cloud SQL instance using the Cloud SQL Admin API.
|
||||
|
||||
{{< notice info dd>}}
|
||||
This tool uses a `source` of kind `cloud-sql-admin`.
|
||||
{{< /notice >}}
|
||||
|
||||
## Examples
|
||||
|
||||
Basic backup creation (current state)
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
backup-creation-basic:
|
||||
kind: cloud-sql-create-backup
|
||||
source: cloud-sql-admin-source
|
||||
description: "Creates a backup on the given Cloud SQL instance."
|
||||
```
|
||||
## Reference
|
||||
### Tool Configuration
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| -------------- | :------: | :----------: | ------------------------------------------------------------- |
|
||||
| kind | string | true | Must be "cloud-sql-create-backup". |
|
||||
| source | string | true | The name of the `cloud-sql-admin` source to use. |
|
||||
| description | string | false | A description of the tool. |
|
||||
|
||||
### Tool Inputs
|
||||
|
||||
| **parameter** | **type** | **required** | **description** |
|
||||
| -------------------------- | :------: | :----------: | ------------------------------------------------------------------------------- |
|
||||
| project | string | true | The project ID. |
|
||||
| instance | string | true | The name of the instance to take a backup on. Does not include the project ID. |
|
||||
| location | string | false | (Optional) Location of the backup run. |
|
||||
| backup_description | string | false | (Optional) The description of this backup run. |
|
||||
|
||||
## See Also
|
||||
- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api)
|
||||
- [Toolbox Cloud SQL tools documentation](../cloudsql)
|
||||
- [Cloud SQL Backup API documentation](https://cloud.google.com/sql/docs/mysql/backup-recovery/backups)
|
||||
@@ -365,7 +365,7 @@ pip install llama-index-llms-google-genai
|
||||
|
||||
{{< /tab >}}
|
||||
{{< tab header="ADK" lang="bash" >}}
|
||||
pip install toolbox-core
|
||||
pip install toolbox-adk
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
|
||||
@@ -607,8 +607,8 @@ from google.adk.agents import Agent
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions import InMemorySessionService
|
||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from google.adk.tools.toolbox_toolset import ToolboxToolset
|
||||
from google.genai import types # For constructing message content
|
||||
from toolbox_core import ToolboxSyncClient
|
||||
|
||||
import os
|
||||
os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = 'True'
|
||||
@@ -623,48 +623,47 @@ os.environ['GOOGLE_CLOUD_LOCATION'] = 'us-central1'
|
||||
|
||||
# --- Load Tools from Toolbox ---
|
||||
|
||||
# TODO(developer): Ensure the Toolbox server is running at <http://127.0.0.1:5000>
|
||||
# TODO(developer): Ensure the Toolbox server is running at http://127.0.0.1:5000
|
||||
toolset = ToolboxToolset(server_url="http://127.0.0.1:5000")
|
||||
|
||||
with ToolboxSyncClient("<http://127.0.0.1:5000>") as toolbox_client:
|
||||
# TODO(developer): Replace "my-toolset" with the actual ID of your toolset as configured in your MCP Toolbox server.
|
||||
agent_toolset = toolbox_client.load_toolset("my-toolset")
|
||||
# --- Define the Agent's Prompt ---
|
||||
prompt = """
|
||||
You're a helpful hotel assistant. You handle hotel searching, booking and
|
||||
cancellations. When the user searches for a hotel, mention it's name, id,
|
||||
location and price tier. Always mention hotel ids while performing any
|
||||
searches. This is very important for any operations. For any bookings or
|
||||
cancellations, please provide the appropriate confirmation. Be sure to
|
||||
update checkin or checkout dates if mentioned by the user.
|
||||
Don't ask for confirmations from the user.
|
||||
"""
|
||||
|
||||
# --- Define the Agent's Prompt ---
|
||||
prompt = """
|
||||
You're a helpful hotel assistant. You handle hotel searching, booking and
|
||||
cancellations. When the user searches for a hotel, mention it's name, id,
|
||||
location and price tier. Always mention hotel ids while performing any
|
||||
searches. This is very important for any operations. For any bookings or
|
||||
cancellations, please provide the appropriate confirmation. Be sure to
|
||||
update checkin or checkout dates if mentioned by the user.
|
||||
Don't ask for confirmations from the user.
|
||||
"""
|
||||
# --- Configure the Agent ---
|
||||
|
||||
# --- Configure the Agent ---
|
||||
root_agent = Agent(
|
||||
model='gemini-2.0-flash-001',
|
||||
name='hotel_agent',
|
||||
description='A helpful AI assistant that can search and book hotels.',
|
||||
instruction=prompt,
|
||||
tools=[toolset], # Pass the loaded toolset
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
model='gemini-2.0-flash-001',
|
||||
name='hotel_agent',
|
||||
description='A helpful AI assistant that can search and book hotels.',
|
||||
instruction=prompt,
|
||||
tools=agent_toolset, # Pass the loaded toolset
|
||||
)
|
||||
# --- Initialize Services for Running the Agent ---
|
||||
session_service = InMemorySessionService()
|
||||
artifacts_service = InMemoryArtifactService()
|
||||
|
||||
# --- Initialize Services for Running the Agent ---
|
||||
session_service = InMemorySessionService()
|
||||
artifacts_service = InMemoryArtifactService()
|
||||
runner = Runner(
|
||||
app_name='hotel_agent',
|
||||
agent=root_agent,
|
||||
artifact_service=artifacts_service,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
async def main():
|
||||
# Create a new session for the interaction.
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
state={}, app_name='hotel_agent', user_id='123'
|
||||
)
|
||||
|
||||
runner = Runner(
|
||||
app_name='hotel_agent',
|
||||
agent=root_agent,
|
||||
artifact_service=artifacts_service,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
# --- Define Queries and Run the Agent ---
|
||||
queries = [
|
||||
"Find hotels in Basel with Basel in it's name.",
|
||||
@@ -687,6 +686,10 @@ with ToolboxSyncClient("<http://127.0.0.1:5000>") as toolbox_client:
|
||||
|
||||
for text in responses:
|
||||
print(text)
|
||||
|
||||
import asyncio
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
|
||||
|
||||
@@ -21,13 +21,13 @@ import (
|
||||
|
||||
// AuthServiceConfig is the interface for configuring authentication services.
|
||||
type AuthServiceConfig interface {
|
||||
AuthServiceConfigType() string
|
||||
AuthServiceConfigKind() string
|
||||
Initialize() (AuthService, error)
|
||||
}
|
||||
|
||||
// AuthService is the interface for authentication services.
|
||||
type AuthService interface {
|
||||
AuthServiceType() string
|
||||
AuthServiceKind() string
|
||||
GetName() string
|
||||
GetClaimsFromHeader(context.Context, http.Header) (map[string]any, error)
|
||||
ToConfig() AuthServiceConfig
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
"google.golang.org/api/idtoken"
|
||||
)
|
||||
|
||||
const AuthServiceType string = "google"
|
||||
const AuthServiceKind string = "google"
|
||||
|
||||
// validate interface
|
||||
var _ auth.AuthServiceConfig = Config{}
|
||||
@@ -31,13 +31,13 @@ var _ auth.AuthServiceConfig = Config{}
|
||||
// Auth service configuration
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
ClientID string `yaml:"clientId" validate:"required"`
|
||||
}
|
||||
|
||||
// Returns the auth service type
|
||||
func (cfg Config) AuthServiceConfigType() string {
|
||||
return AuthServiceType
|
||||
// Returns the auth service kind
|
||||
func (cfg Config) AuthServiceConfigKind() string {
|
||||
return AuthServiceKind
|
||||
}
|
||||
|
||||
// Initialize a Google auth service
|
||||
@@ -55,9 +55,9 @@ type AuthService struct {
|
||||
Config
|
||||
}
|
||||
|
||||
// Returns the auth service type
|
||||
func (a AuthService) AuthServiceType() string {
|
||||
return AuthServiceType
|
||||
// Returns the auth service kind
|
||||
func (a AuthService) AuthServiceKind() string {
|
||||
return AuthServiceKind
|
||||
}
|
||||
|
||||
func (a AuthService) ToConfig() auth.AuthServiceConfig {
|
||||
|
||||
@@ -22,12 +22,12 @@ import (
|
||||
|
||||
// EmbeddingModelConfig is the interface for configuring embedding models.
|
||||
type EmbeddingModelConfig interface {
|
||||
EmbeddingModelConfigType() string
|
||||
EmbeddingModelConfigKind() string
|
||||
Initialize(context.Context) (EmbeddingModel, error)
|
||||
}
|
||||
|
||||
type EmbeddingModel interface {
|
||||
EmbeddingModelType() string
|
||||
EmbeddingModelKind() string
|
||||
ToConfig() EmbeddingModelConfig
|
||||
EmbedParameters(context.Context, []string) ([][]float32, error)
|
||||
}
|
||||
|
||||
@@ -23,22 +23,22 @@ import (
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
const EmbeddingModelType string = "gemini"
|
||||
const EmbeddingModelKind string = "gemini"
|
||||
|
||||
// validate interface
|
||||
var _ embeddingmodels.EmbeddingModelConfig = Config{}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Model string `yaml:"model" validate:"required"`
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
Dimension int32 `yaml:"dimension"`
|
||||
}
|
||||
|
||||
// Returns the embedding model type
|
||||
func (cfg Config) EmbeddingModelConfigType() string {
|
||||
return EmbeddingModelType
|
||||
// Returns the embedding model kind
|
||||
func (cfg Config) EmbeddingModelConfigKind() string {
|
||||
return EmbeddingModelKind
|
||||
}
|
||||
|
||||
// Initialize a Gemini embedding model
|
||||
@@ -69,9 +69,9 @@ type EmbeddingModel struct {
|
||||
Config
|
||||
}
|
||||
|
||||
// Returns the embedding model type
|
||||
func (m EmbeddingModel) EmbeddingModelType() string {
|
||||
return EmbeddingModelType
|
||||
// Returns the embedding model kind
|
||||
func (m EmbeddingModel) EmbeddingModelKind() string {
|
||||
return EmbeddingModelKind
|
||||
}
|
||||
|
||||
func (m EmbeddingModel) ToConfig() embeddingmodels.EmbeddingModelConfig {
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
package gemini_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
|
||||
@@ -34,15 +34,15 @@ func TestParseFromYamlGemini(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: embeddingModels
|
||||
name: my-gemini-model
|
||||
type: gemini
|
||||
model: text-embedding-004
|
||||
embeddingModels:
|
||||
my-gemini-model:
|
||||
kind: gemini
|
||||
model: text-embedding-004
|
||||
`,
|
||||
want: map[string]embeddingmodels.EmbeddingModelConfig{
|
||||
"my-gemini-model": gemini.Config{
|
||||
Name: "my-gemini-model",
|
||||
Type: gemini.EmbeddingModelType,
|
||||
Kind: gemini.EmbeddingModelKind,
|
||||
Model: "text-embedding-004",
|
||||
},
|
||||
},
|
||||
@@ -50,17 +50,17 @@ func TestParseFromYamlGemini(t *testing.T) {
|
||||
{
|
||||
desc: "full example with optional fields",
|
||||
in: `
|
||||
kind: embeddingModels
|
||||
name: complex-gemini
|
||||
type: gemini
|
||||
model: text-embedding-004
|
||||
apiKey: "test-api-key"
|
||||
dimension: 768
|
||||
embeddingModels:
|
||||
complex-gemini:
|
||||
kind: gemini
|
||||
model: text-embedding-004
|
||||
apiKey: "test-api-key"
|
||||
dimension: 768
|
||||
`,
|
||||
want: map[string]embeddingmodels.EmbeddingModelConfig{
|
||||
"complex-gemini": gemini.Config{
|
||||
Name: "complex-gemini",
|
||||
Type: gemini.EmbeddingModelType,
|
||||
Kind: gemini.EmbeddingModelKind,
|
||||
Model: "text-embedding-004",
|
||||
ApiKey: "test-api-key",
|
||||
Dimension: 768,
|
||||
@@ -70,13 +70,16 @@ func TestParseFromYamlGemini(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
|
||||
}{}
|
||||
// Parse contents
|
||||
_, _, got, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: %v", cmp.Diff(tc.want, got))
|
||||
if !cmp.Equal(tc.want, got.Models) {
|
||||
t.Fatalf("incorrect parse: %v", cmp.Diff(tc.want, got.Models))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -90,29 +93,32 @@ func TestFailParseFromYamlGemini(t *testing.T) {
|
||||
{
|
||||
desc: "missing required model field",
|
||||
in: `
|
||||
kind: embeddingModels
|
||||
name: bad-model
|
||||
type: gemini
|
||||
embeddingModels:
|
||||
bad-model:
|
||||
kind: gemini
|
||||
`,
|
||||
// Removed the specific model name from the prefix to match your output
|
||||
err: "error unmarshaling embeddingModels: unable to parse as \"bad-model\": Key: 'Config.Model' Error:Field validation for 'Model' failed on the 'required' tag",
|
||||
err: "unable to parse as \"gemini\": Key: 'Config.Model' Error:Field validation for 'Model' failed on the 'required' tag",
|
||||
},
|
||||
{
|
||||
desc: "unknown field",
|
||||
in: `
|
||||
kind: embeddingModels
|
||||
name: bad-field
|
||||
type: gemini
|
||||
model: text-embedding-004
|
||||
invalid_param: true
|
||||
embeddingModels:
|
||||
bad-field:
|
||||
kind: gemini
|
||||
model: text-embedding-004
|
||||
invalid_param: true
|
||||
`,
|
||||
// Updated to match the specific line-starting format of your error output
|
||||
err: "error unmarshaling embeddingModels: unable to parse as \"bad-field\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | model: text-embedding-004\n 3 | name: bad-field\n 4 | type: gemini",
|
||||
err: "unable to parse as \"gemini\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | kind: gemini\n 3 | model: text-embedding-004",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
|
||||
}{}
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -43,9 +43,6 @@ tools:
|
||||
clone_instance:
|
||||
kind: cloud-sql-clone-instance
|
||||
source: cloud-sql-admin-source
|
||||
create_backup:
|
||||
kind: cloud-sql-create-backup
|
||||
source: cloud-sql-admin-source
|
||||
|
||||
toolsets:
|
||||
cloud_sql_mssql_admin_tools:
|
||||
@@ -57,4 +54,3 @@ toolsets:
|
||||
- create_user
|
||||
- wait_for_operation
|
||||
- clone_instance
|
||||
- create_backup
|
||||
|
||||
@@ -43,9 +43,6 @@ tools:
|
||||
clone_instance:
|
||||
kind: cloud-sql-clone-instance
|
||||
source: cloud-sql-admin-source
|
||||
create_backup:
|
||||
kind: cloud-sql-create-backup
|
||||
source: cloud-sql-admin-source
|
||||
|
||||
toolsets:
|
||||
cloud_sql_mysql_admin_tools:
|
||||
@@ -57,4 +54,3 @@ toolsets:
|
||||
- create_user
|
||||
- wait_for_operation
|
||||
- clone_instance
|
||||
- create_backup
|
||||
|
||||
@@ -46,9 +46,6 @@ tools:
|
||||
postgres_upgrade_precheck:
|
||||
kind: postgres-upgrade-precheck
|
||||
source: cloud-sql-admin-source
|
||||
create_backup:
|
||||
kind: cloud-sql-create-backup
|
||||
source: cloud-sql-admin-source
|
||||
|
||||
toolsets:
|
||||
cloud_sql_postgres_admin_tools:
|
||||
@@ -61,4 +58,3 @@ toolsets:
|
||||
- wait_for_operation
|
||||
- postgres_upgrade_precheck
|
||||
- clone_instance
|
||||
- create_backup
|
||||
|
||||
@@ -25,12 +25,12 @@ import (
|
||||
|
||||
type Message = prompts.Message
|
||||
|
||||
const resourceType = "custom"
|
||||
const kind = "custom"
|
||||
|
||||
// init registers this prompt type with the prompt framework.
|
||||
// init registers this prompt kind with the prompt framework.
|
||||
func init() {
|
||||
if !prompts.Register(resourceType, newConfig) {
|
||||
panic(fmt.Sprintf("prompt type %q already registered", resourceType))
|
||||
if !prompts.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("prompt kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,8 +56,8 @@ type Config struct {
|
||||
var _ prompts.PromptConfig = Config{}
|
||||
var _ prompts.Prompt = Prompt{}
|
||||
|
||||
func (c Config) PromptConfigType() string {
|
||||
return resourceType
|
||||
func (c Config) PromptConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (c Config) Initialize() (prompts.Prompt, error) {
|
||||
|
||||
@@ -42,7 +42,7 @@ func TestConfig(t *testing.T) {
|
||||
Arguments: testArgs,
|
||||
}
|
||||
|
||||
// initialize and check type
|
||||
// initialize and check kind
|
||||
p, err := cfg.Initialize()
|
||||
if err != nil {
|
||||
t.Fatalf("Initialize() failed: %v", err)
|
||||
@@ -50,8 +50,8 @@ func TestConfig(t *testing.T) {
|
||||
if p == nil {
|
||||
t.Fatal("Initialize() returned a nil prompt")
|
||||
}
|
||||
if cfg.PromptConfigType() != "custom" {
|
||||
t.Errorf("PromptConfigType() = %q, want %q", cfg.PromptConfigType(), "custom")
|
||||
if cfg.PromptConfigKind() != "custom" {
|
||||
t.Errorf("PromptConfigKind() = %q, want %q", cfg.PromptConfigKind(), "custom")
|
||||
}
|
||||
|
||||
t.Run("Manifest", func(t *testing.T) {
|
||||
|
||||
@@ -30,40 +30,40 @@ var promptRegistry = make(map[string]PromptConfigFactory)
|
||||
|
||||
// Register allows individual prompt packages to register their configuration
|
||||
// factory function. This is typically called from an init() function in the
|
||||
// prompt's package. It associates a 'type' string with a function that can
|
||||
// prompt's package. It associates a 'kind' string with a function that can
|
||||
// produce the specific PromptConfig type. It returns true if the registration was
|
||||
// successful, and false if a prompt with the same type was already registered.
|
||||
func Register(resourceType string, factory PromptConfigFactory) bool {
|
||||
if _, exists := promptRegistry[resourceType]; exists {
|
||||
// Prompt with this type already exists, do not overwrite.
|
||||
// successful, and false if a prompt with the same kind was already registered.
|
||||
func Register(kind string, factory PromptConfigFactory) bool {
|
||||
if _, exists := promptRegistry[kind]; exists {
|
||||
// Prompt with this kind already exists, do not overwrite.
|
||||
return false
|
||||
}
|
||||
promptRegistry[resourceType] = factory
|
||||
promptRegistry[kind] = factory
|
||||
return true
|
||||
}
|
||||
|
||||
// DecodeConfig looks up the registered factory for the given type and uses it
|
||||
// DecodeConfig looks up the registered factory for the given kind and uses it
|
||||
// to decode the prompt configuration.
|
||||
func DecodeConfig(ctx context.Context, resourceType, name string, decoder *yaml.Decoder) (PromptConfig, error) {
|
||||
factory, found := promptRegistry[resourceType]
|
||||
if !found && resourceType == "" {
|
||||
resourceType = "custom"
|
||||
factory, found = promptRegistry[resourceType]
|
||||
func DecodeConfig(ctx context.Context, kind, name string, decoder *yaml.Decoder) (PromptConfig, error) {
|
||||
factory, found := promptRegistry[kind]
|
||||
if !found && kind == "" {
|
||||
kind = "custom"
|
||||
factory, found = promptRegistry[kind]
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, fmt.Errorf("unknown prompt type: %q", resourceType)
|
||||
return nil, fmt.Errorf("unknown prompt kind: %q", kind)
|
||||
}
|
||||
|
||||
promptConfig, err := factory(ctx, name, decoder)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse prompt %q as resourceType %q: %w", name, resourceType, err)
|
||||
return nil, fmt.Errorf("unable to parse prompt %q as kind %q: %w", name, kind, err)
|
||||
}
|
||||
return promptConfig, nil
|
||||
}
|
||||
|
||||
type PromptConfig interface {
|
||||
PromptConfigType() string
|
||||
PromptConfigKind() string
|
||||
Initialize() (Prompt, error)
|
||||
}
|
||||
|
||||
|
||||
@@ -29,16 +29,16 @@ import (
|
||||
|
||||
type mockPromptConfig struct {
|
||||
name string
|
||||
Type string
|
||||
kind string
|
||||
}
|
||||
|
||||
func (m *mockPromptConfig) PromptConfigType() string { return m.Type }
|
||||
func (m *mockPromptConfig) PromptConfigKind() string { return m.kind }
|
||||
func (m *mockPromptConfig) Initialize() (prompts.Prompt, error) { return nil, nil }
|
||||
|
||||
var errMockFactory = errors.New("mock factory error")
|
||||
|
||||
func mockFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
|
||||
return &mockPromptConfig{name: name, Type: "mockType"}, nil
|
||||
return &mockPromptConfig{name: name, kind: "mockKind"}, nil
|
||||
}
|
||||
|
||||
func mockErrorFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
|
||||
@@ -50,17 +50,17 @@ func TestRegistry(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("RegisterAndDecodeSuccess", func(t *testing.T) {
|
||||
resourceType := "testTypeSuccess"
|
||||
if !prompts.Register(resourceType, mockFactory) {
|
||||
kind := "testKindSuccess"
|
||||
if !prompts.Register(kind, mockFactory) {
|
||||
t.Fatal("expected registration to succeed")
|
||||
}
|
||||
// This should fail because we are registering a duplicate
|
||||
if prompts.Register(resourceType, mockFactory) {
|
||||
if prompts.Register(kind, mockFactory) {
|
||||
t.Fatal("expected duplicate registration to fail")
|
||||
}
|
||||
|
||||
decoder := yaml.NewDecoder(strings.NewReader(""))
|
||||
config, err := prompts.DecodeConfig(ctx, resourceType, "testPrompt", decoder)
|
||||
config, err := prompts.DecodeConfig(ctx, kind, "testPrompt", decoder)
|
||||
if err != nil {
|
||||
t.Fatalf("expected DecodeConfig to succeed, but got error: %v", err)
|
||||
}
|
||||
@@ -69,25 +69,25 @@ func TestRegistry(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DecodeUnknownType", func(t *testing.T) {
|
||||
t.Run("DecodeUnknownKind", func(t *testing.T) {
|
||||
decoder := yaml.NewDecoder(strings.NewReader(""))
|
||||
_, err := prompts.DecodeConfig(ctx, "unregisteredType", "testPrompt", decoder)
|
||||
_, err := prompts.DecodeConfig(ctx, "unregisteredKind", "testPrompt", decoder)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error for unknown type, but got nil")
|
||||
t.Fatal("expected an error for unknown kind, but got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown prompt type") {
|
||||
t.Errorf("expected error to contain 'unknown prompt type', but got: %v", err)
|
||||
if !strings.Contains(err.Error(), "unknown prompt kind") {
|
||||
t.Errorf("expected error to contain 'unknown prompt kind', but got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FactoryReturnsError", func(t *testing.T) {
|
||||
resourceType := "testTypeError"
|
||||
if !prompts.Register(resourceType, mockErrorFactory) {
|
||||
kind := "testKindError"
|
||||
if !prompts.Register(kind, mockErrorFactory) {
|
||||
t.Fatal("expected registration to succeed")
|
||||
}
|
||||
|
||||
decoder := yaml.NewDecoder(strings.NewReader(""))
|
||||
_, err := prompts.DecodeConfig(ctx, resourceType, "testPrompt", decoder)
|
||||
_, err := prompts.DecodeConfig(ctx, kind, "testPrompt", decoder)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error from the factory, but got nil")
|
||||
}
|
||||
@@ -100,13 +100,13 @@ func TestRegistry(t *testing.T) {
|
||||
decoder := yaml.NewDecoder(strings.NewReader("description: A test prompt"))
|
||||
config, err := prompts.DecodeConfig(ctx, "", "testDefaultPrompt", decoder)
|
||||
if err != nil {
|
||||
t.Fatalf("expected DecodeConfig with empty type to succeed, but got error: %v", err)
|
||||
t.Fatalf("expected DecodeConfig with empty kind to succeed, but got error: %v", err)
|
||||
}
|
||||
if config == nil {
|
||||
t.Fatal("expected a non-nil config for default type")
|
||||
t.Fatal("expected a non-nil config for default kind")
|
||||
}
|
||||
if config.PromptConfigType() != "custom" {
|
||||
t.Errorf("expected default type to be 'custom', but got %q", config.PromptConfigType())
|
||||
if config.PromptConfigKind() != "custom" {
|
||||
t.Errorf("expected default kind to be 'custom', but got %q", config.PromptConfigKind())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,10 +14,8 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
@@ -126,201 +124,272 @@ func (s *StringLevel) Type() string {
|
||||
return "stringLevel"
|
||||
}
|
||||
|
||||
// SourceConfigs is a type used to allow unmarshal of the data source config map
|
||||
type SourceConfigs map[string]sources.SourceConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &SourceConfigs{}
|
||||
|
||||
func (c *SourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(SourceConfigs)
|
||||
// Parse the 'kind' fields for each source
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
// Unmarshal to a general type that ensure it capture all fields
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
|
||||
}
|
||||
|
||||
kind, ok := v["kind"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'kind' field for source %q", name)
|
||||
}
|
||||
kindStr, ok := kind.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid 'kind' field for source %q (must be a string)", name)
|
||||
}
|
||||
|
||||
yamlDecoder, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating YAML decoder for source %q: %w", name, err)
|
||||
}
|
||||
|
||||
sourceConfig, err := sources.DecodeConfig(ctx, kindStr, name, yamlDecoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
(*c)[name] = sourceConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthServiceConfigs is a type used to allow unmarshal of the data authService config map
|
||||
type AuthServiceConfigs map[string]auth.AuthServiceConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &AuthServiceConfigs{}
|
||||
|
||||
func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(AuthServiceConfigs)
|
||||
// Parse the 'kind' fields for each authService
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
|
||||
}
|
||||
|
||||
kind, ok := v["kind"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'kind' field for %q", name)
|
||||
}
|
||||
|
||||
dec, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
switch kind {
|
||||
case google.AuthServiceKind:
|
||||
actual := google.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of auth source", kind)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EmbeddingModelConfigs is a type used to allow unmarshal of the embedding model config map
|
||||
type EmbeddingModelConfigs map[string]embeddingmodels.EmbeddingModelConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &EmbeddingModelConfigs{}
|
||||
|
||||
func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(EmbeddingModelConfigs)
|
||||
// Parse the 'kind' fields for each embedding model
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
// Unmarshal to a general type that ensure it capture all fields
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal embedding model %q: %w", name, err)
|
||||
}
|
||||
|
||||
kind, ok := v["kind"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'kind' field for embedding model %q", name)
|
||||
}
|
||||
|
||||
dec, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
switch kind {
|
||||
case gemini.EmbeddingModelKind:
|
||||
actual := gemini.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of auth source", kind)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToolConfigs is a type used to allow unmarshal of the tool configs
|
||||
type ToolConfigs map[string]tools.ToolConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &ToolConfigs{}
|
||||
|
||||
func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(ToolConfigs)
|
||||
// Parse the 'kind' fields for each source
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
|
||||
}
|
||||
|
||||
// `authRequired` and `useClientOAuth` cannot be specified together
|
||||
if v["authRequired"] != nil && v["useClientOAuth"] == true {
|
||||
return fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method")
|
||||
}
|
||||
|
||||
// Make `authRequired` an empty list instead of nil for Tool manifest
|
||||
if v["authRequired"] == nil {
|
||||
v["authRequired"] = []string{}
|
||||
}
|
||||
|
||||
kindVal, ok := v["kind"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'kind' field for tool %q", name)
|
||||
}
|
||||
kindStr, ok := kindVal.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid 'kind' field for tool %q (must be a string)", name)
|
||||
}
|
||||
|
||||
yamlDecoder, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating YAML decoder for tool %q: %w", name, err)
|
||||
}
|
||||
|
||||
toolCfg, err := tools.DecodeConfig(ctx, kindStr, name, yamlDecoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
(*c)[name] = toolCfg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToolsetConfigs is a type used to allow unmarshal of the toolset configs
|
||||
type ToolsetConfigs map[string]tools.ToolsetConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &ToolsetConfigs{}
|
||||
|
||||
func (c *ToolsetConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(ToolsetConfigs)
|
||||
|
||||
var raw map[string][]string
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, toolList := range raw {
|
||||
(*c)[name] = tools.ToolsetConfig{Name: name, ToolNames: toolList}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PromptConfigs is a type used to allow unmarshal of the prompt configs
|
||||
type PromptConfigs map[string]prompts.PromptConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &PromptConfigs{}
|
||||
|
||||
func (c *PromptConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(PromptConfigs)
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal prompt %q: %w", name, err)
|
||||
}
|
||||
|
||||
// Look for the 'kind' field. If it's not present, kindStr will be an
|
||||
// empty string, which prompts.DecodeConfig will correctly default to "custom".
|
||||
var kindStr string
|
||||
if kindVal, ok := v["kind"]; ok {
|
||||
var isString bool
|
||||
kindStr, isString = kindVal.(string)
|
||||
if !isString {
|
||||
return fmt.Errorf("invalid 'kind' field for prompt %q (must be a string)", name)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new, strict decoder for this specific prompt's data.
|
||||
yamlDecoder, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating YAML decoder for prompt %q: %w", name, err)
|
||||
}
|
||||
|
||||
// Use the central registry to decode the prompt based on its kind.
|
||||
promptCfg, err := prompts.DecodeConfig(ctx, kindStr, name, yamlDecoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
(*c)[name] = promptCfg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PromptsetConfigs is a type used to allow unmarshal of the PromptsetConfigs configs
|
||||
type PromptsetConfigs map[string]prompts.PromptsetConfig
|
||||
|
||||
func UnmarshalResourceConfig(ctx context.Context, raw []byte) (SourceConfigs, AuthServiceConfigs, EmbeddingModelConfigs, ToolConfigs, ToolsetConfigs, PromptConfigs, error) {
|
||||
// prepare configs map
|
||||
sourceConfigs := make(map[string]sources.SourceConfig)
|
||||
authServiceConfigs := make(AuthServiceConfigs)
|
||||
embeddingModelConfigs := make(EmbeddingModelConfigs)
|
||||
toolConfigs := make(ToolConfigs)
|
||||
toolsetConfigs := make(ToolsetConfigs)
|
||||
promptConfigs := make(PromptConfigs)
|
||||
// promptset configs is not yet supported
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &PromptsetConfigs{}
|
||||
|
||||
decoder := yaml.NewDecoder(bytes.NewReader(raw))
|
||||
// for loop to unmarshal documents with the `---` separator
|
||||
for {
|
||||
var resource map[string]any
|
||||
if err := decoder.DecodeContext(ctx, &resource); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unable to decode YAML document: %w", err)
|
||||
}
|
||||
var kind, name string
|
||||
var ok bool
|
||||
if kind, ok = resource["kind"].(string); !ok {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("missing 'kind' field or it is not a string")
|
||||
}
|
||||
if name, ok = resource["name"].(string); !ok {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("missing 'name' field or it is not a string")
|
||||
}
|
||||
// remove 'kind' from map for strict unmarshaling
|
||||
delete(resource, "kind")
|
||||
func (c *PromptsetConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(PromptsetConfigs)
|
||||
|
||||
switch kind {
|
||||
case "sources":
|
||||
c, err := UnmarshalYAMLSourceConfig(ctx, name, resource)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
|
||||
}
|
||||
sourceConfigs[name] = c
|
||||
case "authServices":
|
||||
c, err := UnmarshalYAMLAuthServiceConfig(ctx, name, resource)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
|
||||
}
|
||||
authServiceConfigs[name] = c
|
||||
case "tools":
|
||||
c, err := UnmarshalYAMLToolConfig(ctx, name, resource)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
|
||||
}
|
||||
toolConfigs[name] = c
|
||||
case "toolsets":
|
||||
c, err := UnmarshalYAMLToolsetConfig(ctx, name, resource)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
|
||||
}
|
||||
toolsetConfigs[name] = c
|
||||
case "embeddingModels":
|
||||
c, err := UnmarshalYAMLEmbeddingModelConfig(ctx, name, resource)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
|
||||
}
|
||||
embeddingModelConfigs[name] = c
|
||||
case "prompts":
|
||||
c, err := UnmarshalYAMLPromptConfig(ctx, name, resource)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
|
||||
}
|
||||
promptConfigs[name] = c
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("invalid kind %s", kind)
|
||||
}
|
||||
}
|
||||
return sourceConfigs, authServiceConfigs, embeddingModelConfigs, toolConfigs, toolsetConfigs, promptConfigs, nil
|
||||
}
|
||||
|
||||
func UnmarshalYAMLSourceConfig(ctx context.Context, name string, r map[string]any) (sources.SourceConfig, error) {
|
||||
resourceType, ok := r["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'type' field or it is not a string")
|
||||
}
|
||||
dec, err := util.NewStrictDecoder(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
sourceConfig, err := sources.DecodeConfig(ctx, resourceType, name, dec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sourceConfig, nil
|
||||
}
|
||||
|
||||
func UnmarshalYAMLAuthServiceConfig(ctx context.Context, name string, r map[string]any) (auth.AuthServiceConfig, error) {
|
||||
resourceType, ok := r["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'type' field or it is not a string")
|
||||
}
|
||||
if resourceType != google.AuthServiceType {
|
||||
return nil, fmt.Errorf("%s is not a valid type of auth service", resourceType)
|
||||
}
|
||||
dec, err := util.NewStrictDecoder(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating decoder: %s", err)
|
||||
}
|
||||
actual := google.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %s: %w", name, err)
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
func UnmarshalYAMLEmbeddingModelConfig(ctx context.Context, name string, r map[string]any) (embeddingmodels.EmbeddingModelConfig, error) {
|
||||
resourceType, ok := r["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'type' field or it is not a string")
|
||||
}
|
||||
if resourceType != gemini.EmbeddingModelType {
|
||||
return nil, fmt.Errorf("%s is not a valid type of embedding model", resourceType)
|
||||
}
|
||||
dec, err := util.NewStrictDecoder(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating decoder: %s", err)
|
||||
}
|
||||
actual := gemini.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", name, err)
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
func UnmarshalYAMLToolConfig(ctx context.Context, name string, r map[string]any) (tools.ToolConfig, error) {
|
||||
resourceType, ok := r["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'type' field or it is not a string")
|
||||
}
|
||||
// `authRequired` and `useClientOAuth` cannot be specified together
|
||||
if r["authRequired"] != nil && r["useClientOAuth"] == true {
|
||||
return nil, fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method")
|
||||
}
|
||||
// Make `authRequired` an empty list instead of nil for Tool manifest
|
||||
if r["authRequired"] == nil {
|
||||
r["authRequired"] = []string{}
|
||||
}
|
||||
dec, err := util.NewStrictDecoder(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating decoder: %s", err)
|
||||
}
|
||||
toolCfg, err := tools.DecodeConfig(ctx, resourceType, name, dec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toolCfg, nil
|
||||
}
|
||||
|
||||
func UnmarshalYAMLToolsetConfig(ctx context.Context, name string, r map[string]any) (tools.ToolsetConfig, error) {
|
||||
var toolsetConfig tools.ToolsetConfig
|
||||
justTools := map[string]any{"tools": r["tools"]}
|
||||
dec, err := util.NewStrictDecoder(justTools)
|
||||
if err != nil {
|
||||
return toolsetConfig, fmt.Errorf("error creating decoder: %s", err)
|
||||
}
|
||||
var raw map[string][]string
|
||||
if err := dec.DecodeContext(ctx, &raw); err != nil {
|
||||
return toolsetConfig, fmt.Errorf("unable to unmarshal tools: %s", err)
|
||||
}
|
||||
return tools.ToolsetConfig{Name: name, ToolNames: raw["tools"]}, nil
|
||||
}
|
||||
|
||||
func UnmarshalYAMLPromptConfig(ctx context.Context, name string, r map[string]any) (prompts.PromptConfig, error) {
|
||||
// Look for the 'type' field. If it's not present, typeStr will be an
|
||||
// empty string, which prompts.DecodeConfig will correctly default to "custom".
|
||||
var resourceType string
|
||||
if typeVal, ok := r["type"]; ok {
|
||||
var isString bool
|
||||
resourceType, isString = typeVal.(string)
|
||||
if !isString {
|
||||
return nil, fmt.Errorf("invalid 'type' field for prompt %q (must be a string)", name)
|
||||
}
|
||||
}
|
||||
dec, err := util.NewStrictDecoder(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating decoder: %s", err)
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use the central registry to decode the prompt based on its type.
|
||||
promptCfg, err := prompts.DecodeConfig(ctx, resourceType, name, dec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
for name, promptList := range raw {
|
||||
(*c)[name] = prompts.PromptsetConfig{Name: name, PromptNames: promptList}
|
||||
}
|
||||
return promptCfg, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
"example-source": &alloydbpg.Source{
|
||||
Config: alloydbpg.Config{
|
||||
Name: "example-alloydb-source",
|
||||
Type: "alloydb-postgres",
|
||||
Kind: "alloydb-postgres",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
"example-source2": &alloydbpg.Source{
|
||||
Config: alloydbpg.Config{
|
||||
Name: "example-alloydb-source2",
|
||||
Type: "alloydb-postgres",
|
||||
Kind: "alloydb-postgres",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
childCtx, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/source/init",
|
||||
trace.WithAttributes(attribute.String("source_type", sc.SourceConfigType())),
|
||||
trace.WithAttributes(attribute.String("source_kind", sc.SourceConfigKind())),
|
||||
trace.WithAttributes(attribute.String("source_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
@@ -110,7 +110,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/auth/init",
|
||||
trace.WithAttributes(attribute.String("auth_type", sc.AuthServiceConfigType())),
|
||||
trace.WithAttributes(attribute.String("auth_kind", sc.AuthServiceConfigKind())),
|
||||
trace.WithAttributes(attribute.String("auth_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
@@ -138,7 +138,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/embeddingmodel/init",
|
||||
trace.WithAttributes(attribute.String("model_type", ec.EmbeddingModelConfigType())),
|
||||
trace.WithAttributes(attribute.String("model_kind", ec.EmbeddingModelConfigKind())),
|
||||
trace.WithAttributes(attribute.String("model_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
@@ -166,7 +166,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/tool/init",
|
||||
trace.WithAttributes(attribute.String("tool_type", tc.ToolConfigType())),
|
||||
trace.WithAttributes(attribute.String("tool_kind", tc.ToolConfigKind())),
|
||||
trace.WithAttributes(attribute.String("tool_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
@@ -235,7 +235,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/prompt/init",
|
||||
trace.WithAttributes(attribute.String("prompt_type", pc.PromptConfigType())),
|
||||
trace.WithAttributes(attribute.String("prompt_kind", pc.PromptConfigKind())),
|
||||
trace.WithAttributes(attribute.String("prompt_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
|
||||
@@ -141,7 +141,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
"example-source": &alloydbpg.Source{
|
||||
Config: alloydbpg.Config{
|
||||
Name: "example-alloydb-source",
|
||||
Type: "alloydb-postgres",
|
||||
Kind: "alloydb-postgres",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -32,14 +32,14 @@ import (
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const SourceType string = "alloydb-admin"
|
||||
const SourceKind string = "alloydb-admin"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,13 +53,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
DefaultProject string `yaml:"defaultProject"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -106,8 +106,8 @@ type Source struct {
|
||||
Service *alloydbrestapi.Service
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
package alloydbadmin_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -34,14 +34,14 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-alloydb-admin-instance
|
||||
type: alloydb-admin
|
||||
sources:
|
||||
my-alloydb-admin-instance:
|
||||
kind: alloydb-admin
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-alloydb-admin-instance": alloydbadmin.Config{
|
||||
Name: "my-alloydb-admin-instance",
|
||||
Type: alloydbadmin.SourceType,
|
||||
Kind: alloydbadmin.SourceKind,
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
@@ -49,15 +49,15 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-alloydb-admin-instance
|
||||
type: alloydb-admin
|
||||
useClientOAuth: true
|
||||
sources:
|
||||
my-alloydb-admin-instance:
|
||||
kind: alloydb-admin
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-alloydb-admin-instance": alloydbadmin.Config{
|
||||
Name: "my-alloydb-admin-instance",
|
||||
Type: alloydbadmin.SourceType,
|
||||
Kind: alloydbadmin.SourceKind,
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
@@ -65,13 +65,16 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -86,27 +89,30 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-alloydb-admin-instance
|
||||
type: alloydb-admin
|
||||
project: test-project
|
||||
sources:
|
||||
my-alloydb-admin-instance:
|
||||
kind: alloydb-admin
|
||||
project: test-project
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-alloydb-admin-instance\" as \"alloydb-admin\": [2:1] unknown field \"project\"\n 1 | name: my-alloydb-admin-instance\n> 2 | project: test-project\n ^\n 3 | type: alloydb-admin",
|
||||
err: "unable to parse source \"my-alloydb-admin-instance\" as \"alloydb-admin\": [2:1] unknown field \"project\"\n 1 | kind: alloydb-admin\n> 2 | project: test-project\n ^\n",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-alloydb-admin-instance
|
||||
useClientOAuth: true
|
||||
sources:
|
||||
my-alloydb-admin-instance:
|
||||
useClientOAuth: true
|
||||
`,
|
||||
err: "error unmarshaling sources: missing 'type' field or it is not a string",
|
||||
err: "missing 'kind' field for source \"my-alloydb-admin-instance\"",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -29,14 +29,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "alloydb-postgres"
|
||||
const SourceKind string = "alloydb-postgres"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Region string `yaml:"region" validate:"required"`
|
||||
Cluster string `yaml:"cluster" validate:"required"`
|
||||
@@ -61,8 +61,8 @@ type Config struct {
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -90,8 +90,8 @@ type Source struct {
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -183,7 +183,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string
|
||||
|
||||
func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, cluster, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
dsn, useIAM, err := getConnectionConfig(ctx, user, pass, dbname)
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
package alloydbpg_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -34,21 +34,21 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: alloydb-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
cluster: my-cluster
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: alloydb-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
cluster: my-cluster
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-pg-instance": alloydbpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: alloydbpg.SourceType,
|
||||
Kind: alloydbpg.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Cluster: "my-cluster",
|
||||
@@ -63,22 +63,22 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
{
|
||||
desc: "public ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: alloydb-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
cluster: my-cluster
|
||||
instance: my-instance
|
||||
ipType: Public
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: alloydb-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
cluster: my-cluster
|
||||
instance: my-instance
|
||||
ipType: Public
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-pg-instance": alloydbpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: alloydbpg.SourceType,
|
||||
Kind: alloydbpg.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Cluster: "my-cluster",
|
||||
@@ -93,22 +93,22 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
{
|
||||
desc: "private ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: alloydb-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
cluster: my-cluster
|
||||
instance: my-instance
|
||||
ipType: private
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: alloydb-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
cluster: my-cluster
|
||||
instance: my-instance
|
||||
ipType: private
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-pg-instance": alloydbpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: alloydbpg.SourceType,
|
||||
Kind: alloydbpg.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Cluster: "my-cluster",
|
||||
@@ -123,13 +123,16 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -144,56 +147,60 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "invalid ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: alloydb-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
cluster: my-cluster
|
||||
instance: my-instance
|
||||
ipType: fail
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: alloydb-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
cluster: my-cluster
|
||||
instance: my-instance
|
||||
ipType: fail
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
|
||||
err: "unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
|
||||
},
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: alloydb-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
cluster: my-cluster
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: alloydb-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
cluster: my-cluster
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": [3:1] unknown field \"foo\"\n 1 | cluster: my-cluster\n 2 | database: my_db\n> 3 | foo: bar\n ^\n 4 | instance: my-instance\n 5 | name: my-pg-instance\n 6 | password: my_pass\n 7 | ",
|
||||
err: "unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": [3:1] unknown field \"foo\"\n 1 | cluster: my-cluster\n 2 | database: my_db\n> 3 | foo: bar\n ^\n 4 | instance: my-instance\n 5 | kind: alloydb-postgres\n 6 | password: my_pass\n 7 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: alloydb-postgres
|
||||
region: my-region
|
||||
cluster: my-cluster
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: alloydb-postgres
|
||||
region: my-region
|
||||
cluster: my-cluster
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ import (
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const SourceType string = "bigquery"
|
||||
const SourceKind string = "bigquery"
|
||||
|
||||
// CloudPlatformScope is a broad scope for Google Cloud Platform services.
|
||||
const CloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
|
||||
@@ -65,8 +65,8 @@ type BigQuerySessionProvider func(ctx context.Context) (*Session, error)
|
||||
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
// BigQuery configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Location string `yaml:"location"`
|
||||
WriteMode string `yaml:"writeMode"`
|
||||
@@ -119,9 +119,9 @@ func (s *StringOrStringSlice) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
return fmt.Errorf("cannot unmarshal %T into StringOrStringSlice", v)
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
// Returns BigQuery source type
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
// Returns BigQuery source kind
|
||||
return SourceKind
|
||||
}
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
if r.WriteMode == "" {
|
||||
@@ -302,9 +302,9 @@ type Session struct {
|
||||
LastUsed time.Time
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
// Returns BigQuery Google SQL source type
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
// Returns BigQuery Google SQL source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -665,7 +665,7 @@ func initBigQueryConnection(
|
||||
impersonateServiceAccount string,
|
||||
scopes []string,
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
@@ -741,7 +741,7 @@ func initBigQueryConnectionWithOAuthToken(
|
||||
tokenString string,
|
||||
wantRestService bool,
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
// Construct token source
|
||||
token := &oauth2.Token{
|
||||
@@ -801,7 +801,7 @@ func initDataplexConnection(
|
||||
var clientCreator DataplexClientCreator
|
||||
var err error
|
||||
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
|
||||
@@ -15,18 +15,18 @@
|
||||
package bigquery_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
)
|
||||
|
||||
func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
@@ -38,15 +38,15 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: bigquery
|
||||
project: my-project
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Type: bigquery.SourceType,
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "",
|
||||
WriteMode: "",
|
||||
@@ -56,17 +56,17 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
{
|
||||
desc: "all fields specified",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: bigquery
|
||||
project: my-project
|
||||
location: asia
|
||||
writeMode: blocked
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: asia
|
||||
writeMode: blocked
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Type: bigquery.SourceType,
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "asia",
|
||||
WriteMode: "blocked",
|
||||
@@ -77,17 +77,17 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
useClientOAuth: true
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Type: bigquery.SourceType,
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
UseClientOAuth: true,
|
||||
@@ -97,18 +97,18 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
{
|
||||
desc: "with allowed datasets example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
allowedDatasets:
|
||||
- my_dataset
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
allowedDatasets:
|
||||
- my_dataset
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Type: bigquery.SourceType,
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
AllowedDatasets: []string{"my_dataset"},
|
||||
@@ -118,17 +118,17 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
{
|
||||
desc: "with service account impersonation example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
impersonateServiceAccount: service-account@my-project.iam.gserviceaccount.com
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
impersonateServiceAccount: service-account@my-project.iam.gserviceaccount.com
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Type: bigquery.SourceType,
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
ImpersonateServiceAccount: "service-account@my-project.iam.gserviceaccount.com",
|
||||
@@ -138,19 +138,19 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
{
|
||||
desc: "with custom scopes example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
scopes:
|
||||
- https://www.googleapis.com/auth/bigquery
|
||||
- https://www.googleapis.com/auth/cloud-platform
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
scopes:
|
||||
- https://www.googleapis.com/auth/bigquery
|
||||
- https://www.googleapis.com/auth/cloud-platform
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Type: bigquery.SourceType,
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
Scopes: []string{"https://www.googleapis.com/auth/bigquery", "https://www.googleapis.com/auth/cloud-platform"},
|
||||
@@ -160,17 +160,17 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
{
|
||||
desc: "with max query result rows example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
maxQueryResultRows: 10
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
maxQueryResultRows: 10
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Type: bigquery.SourceType,
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
MaxQueryResultRows: 10,
|
||||
@@ -180,15 +180,20 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got); diff != "" {
|
||||
t.Fatalf("incorrect parse (-want +got):\n%s", diff)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
@@ -200,29 +205,33 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
foo: bar
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"bigquery\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | location: us\n 3 | name: my-instance\n 4 | project: my-project\n 5 | ",
|
||||
err: "unable to parse source \"my-instance\" as \"bigquery\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: bigquery\n 3 | location: us\n 4 | project: my-project",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: bigquery
|
||||
location: us
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
location: us
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"bigquery\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-instance\" as \"bigquery\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
@@ -251,7 +260,7 @@ func TestInitialize_MaxQueryResultRows(t *testing.T) {
|
||||
desc: "default value",
|
||||
cfg: bigquery.Config{
|
||||
Name: "test-default",
|
||||
Type: bigquery.SourceType,
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "test-project",
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
@@ -261,7 +270,7 @@ func TestInitialize_MaxQueryResultRows(t *testing.T) {
|
||||
desc: "configured value",
|
||||
cfg: bigquery.Config{
|
||||
Name: "test-configured",
|
||||
Type: bigquery.SourceType,
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "test-project",
|
||||
UseClientOAuth: true,
|
||||
MaxQueryResultRows: 100,
|
||||
|
||||
@@ -27,14 +27,14 @@ import (
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const SourceType string = "bigtable"
|
||||
const SourceKind string = "bigtable"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,13 +48,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Instance string `yaml:"instance" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -77,8 +77,8 @@ type Source struct {
|
||||
Client *bigtable.Client
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -179,7 +179,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, configParam param
|
||||
|
||||
func initBigtableClient(ctx context.Context, tracer trace.Tracer, name, project, instance string) (*bigtable.Client, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Set up Bigtable data operations client.
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
package bigtable_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -34,16 +34,16 @@ func TestParseFromYamlBigtableDb(t *testing.T) {
|
||||
{
|
||||
desc: "can configure with a bigtable table",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-bigtable-instance
|
||||
type: bigtable
|
||||
project: my-project
|
||||
instance: my-instance
|
||||
sources:
|
||||
my-bigtable-instance:
|
||||
kind: bigtable
|
||||
project: my-project
|
||||
instance: my-instance
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-bigtable-instance": bigtable.Config{
|
||||
Name: "my-bigtable-instance",
|
||||
Type: bigtable.SourceType,
|
||||
Kind: bigtable.SourceKind,
|
||||
Project: "my-project",
|
||||
Instance: "my-instance",
|
||||
},
|
||||
@@ -52,12 +52,16 @@ func TestParseFromYamlBigtableDb(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -73,29 +77,33 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-bigtable-instance
|
||||
type: bigtable
|
||||
project: my-project
|
||||
instance: my-instance
|
||||
foo: bar
|
||||
sources:
|
||||
my-bigtable-instance:
|
||||
kind: bigtable
|
||||
project: my-project
|
||||
instance: my-instance
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-bigtable-instance\" as \"bigtable\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | instance: my-instance\n 3 | name: my-bigtable-instance\n 4 | project: my-project\n 5 | ",
|
||||
err: "unable to parse source \"my-bigtable-instance\" as \"bigtable\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | instance: my-instance\n 3 | kind: bigtable\n 4 | project: my-project",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-bigtable-instance
|
||||
type: bigtable
|
||||
project: my-project
|
||||
sources:
|
||||
my-bigtable-instance:
|
||||
kind: bigtable
|
||||
project: my-project
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-bigtable-instance\" as \"bigtable\": Key: 'Config.Instance' Error:Field validation for 'Instance' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-bigtable-instance\" as \"bigtable\": Key: 'Config.Instance' Error:Field validation for 'Instance' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -25,11 +25,11 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "cassandra"
|
||||
const SourceKind string = "cassandra"
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Hosts []string `yaml:"hosts" validate:"required"`
|
||||
Keyspace string `yaml:"keyspace"`
|
||||
ProtoVersion int `yaml:"protoVersion"`
|
||||
@@ -68,9 +68,9 @@ func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// SourceConfigType implements sources.SourceConfig.
|
||||
func (c Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
// SourceConfigKind implements sources.SourceConfig.
|
||||
func (c Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
var _ sources.SourceConfig = Config{}
|
||||
@@ -89,9 +89,9 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
// SourceType implements sources.Source.
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
// SourceKind implements sources.Source.
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) RunSQL(ctx context.Context, statement string, params parameters.ParamValues) (any, error) {
|
||||
@@ -120,7 +120,7 @@ var _ sources.Source = &Source{}
|
||||
|
||||
func initCassandraSession(ctx context.Context, tracer trace.Tracer, c Config) (*gocql.Session, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, c.Name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, c.Name)
|
||||
defer span.End()
|
||||
|
||||
// Validate authentication configuration
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package cassandra_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,17 +33,17 @@ func TestParseFromYamlCassandra(t *testing.T) {
|
||||
{
|
||||
desc: "basic example (without optional fields)",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-cassandra-instance
|
||||
type: cassandra
|
||||
hosts:
|
||||
- "my-host1"
|
||||
- "my-host2"
|
||||
sources:
|
||||
my-cassandra-instance:
|
||||
kind: cassandra
|
||||
hosts:
|
||||
- "my-host1"
|
||||
- "my-host2"
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-cassandra-instance": cassandra.Config{
|
||||
Name: "my-cassandra-instance",
|
||||
Type: cassandra.SourceType,
|
||||
Kind: cassandra.SourceKind,
|
||||
Hosts: []string{"my-host1", "my-host2"},
|
||||
Username: "",
|
||||
Password: "",
|
||||
@@ -60,25 +59,25 @@ func TestParseFromYamlCassandra(t *testing.T) {
|
||||
{
|
||||
desc: "with optional fields",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-cassandra-instance
|
||||
type: cassandra
|
||||
hosts:
|
||||
- "my-host1"
|
||||
- "my-host2"
|
||||
username: "user"
|
||||
password: "pass"
|
||||
keyspace: "example_keyspace"
|
||||
protoVersion: 4
|
||||
caPath: "path/to/ca.crt"
|
||||
certPath: "path/to/cert"
|
||||
keyPath: "path/to/key"
|
||||
enableHostVerification: true
|
||||
sources:
|
||||
my-cassandra-instance:
|
||||
kind: cassandra
|
||||
hosts:
|
||||
- "my-host1"
|
||||
- "my-host2"
|
||||
username: "user"
|
||||
password: "pass"
|
||||
keyspace: "example_keyspace"
|
||||
protoVersion: 4
|
||||
caPath: "path/to/ca.crt"
|
||||
certPath: "path/to/cert"
|
||||
keyPath: "path/to/key"
|
||||
enableHostVerification: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-cassandra-instance": cassandra.Config{
|
||||
Name: "my-cassandra-instance",
|
||||
Type: cassandra.SourceType,
|
||||
Kind: cassandra.SourceKind,
|
||||
Hosts: []string{"my-host1", "my-host2"},
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
@@ -94,12 +93,16 @@ func TestParseFromYamlCassandra(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -115,29 +118,33 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-cassandra-instance
|
||||
type: cassandra
|
||||
hosts:
|
||||
- "my-host"
|
||||
foo: bar
|
||||
sources:
|
||||
my-cassandra-instance:
|
||||
kind: cassandra
|
||||
hosts:
|
||||
- "my-host"
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-cassandra-instance\" as \"cassandra\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | hosts:\n 3 | - my-host\n 4 | name: my-cassandra-instance\n 5 | ",
|
||||
err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | hosts:\n 3 | - my-host\n 4 | kind: cassandra",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-cassandra-instance
|
||||
type: cassandra
|
||||
sources:
|
||||
my-cassandra-instance:
|
||||
kind: cassandra
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-cassandra-instance\" as \"cassandra\": Key: 'Config.Hosts' Error:Field validation for 'Hosts' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": Key: 'Config.Hosts' Error:Field validation for 'Hosts' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -28,14 +28,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "clickhouse"
|
||||
const SourceKind string = "clickhouse"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
@@ -59,8 +59,8 @@ type Config struct {
|
||||
Secure bool `yaml:"secure"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -88,8 +88,8 @@ type Source struct {
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -174,7 +174,7 @@ func validateConfig(protocol string) error {
|
||||
|
||||
func initClickHouseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, protocol string, secure bool) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
if protocol == "" {
|
||||
|
||||
@@ -21,113 +21,137 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"go.opentelemetry.io/otel"
|
||||
)
|
||||
|
||||
func TestParseFromYamlClickhouse(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "all fields specified",
|
||||
in: `
|
||||
kind: sources
|
||||
name: test-clickhouse
|
||||
type: clickhouse
|
||||
host: localhost
|
||||
port: "8443"
|
||||
user: default
|
||||
password: "mypass"
|
||||
database: mydb
|
||||
protocol: https
|
||||
secure: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"test-clickhouse": Config{
|
||||
Name: "test-clickhouse",
|
||||
Type: "clickhouse",
|
||||
Host: "localhost",
|
||||
Port: "8443",
|
||||
User: "default",
|
||||
Password: "mypass",
|
||||
Database: "mydb",
|
||||
Protocol: "https",
|
||||
Secure: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "minimal configuration with defaults",
|
||||
in: `
|
||||
kind: sources
|
||||
name: minimal-clickhouse
|
||||
type: clickhouse
|
||||
host: 127.0.0.1
|
||||
port: "8123"
|
||||
user: testuser
|
||||
database: testdb
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"minimal-clickhouse": Config{
|
||||
Name: "minimal-clickhouse",
|
||||
Type: "clickhouse",
|
||||
Host: "127.0.0.1",
|
||||
Port: "8123",
|
||||
User: "testuser",
|
||||
Password: "",
|
||||
Database: "testdb",
|
||||
Protocol: "",
|
||||
Secure: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
}
|
||||
})
|
||||
func TestConfigSourceConfigKind(t *testing.T) {
|
||||
config := Config{}
|
||||
if config.SourceConfigKind() != SourceKind {
|
||||
t.Errorf("Expected %s, got %s", SourceKind, config.SourceConfigKind())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
func TestNewConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
expected Config
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: test-clickhouse
|
||||
type: clickhouse
|
||||
host: localhost
|
||||
foo: bar
|
||||
name: "all fields specified",
|
||||
yaml: `
|
||||
name: test-clickhouse
|
||||
kind: clickhouse
|
||||
host: localhost
|
||||
port: "8443"
|
||||
user: default
|
||||
password: "mypass"
|
||||
database: mydb
|
||||
protocol: https
|
||||
secure: true
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"test-clickhouse\" as \"clickhouse\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | host: localhost\n 3 | name: test-clickhouse\n 4 | type: clickhouse",
|
||||
expected: Config{
|
||||
Name: "test-clickhouse",
|
||||
Kind: "clickhouse",
|
||||
Host: "localhost",
|
||||
Port: "8443",
|
||||
User: "default",
|
||||
Password: "mypass",
|
||||
Database: "mydb",
|
||||
Protocol: "https",
|
||||
Secure: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "minimal configuration with defaults",
|
||||
yaml: `
|
||||
name: minimal-clickhouse
|
||||
kind: clickhouse
|
||||
host: 127.0.0.1
|
||||
port: "8123"
|
||||
user: testuser
|
||||
database: testdb
|
||||
`,
|
||||
expected: Config{
|
||||
Name: "minimal-clickhouse",
|
||||
Kind: "clickhouse",
|
||||
Host: "127.0.0.1",
|
||||
Port: "8123",
|
||||
User: "testuser",
|
||||
Password: "",
|
||||
Database: "testdb",
|
||||
Protocol: "",
|
||||
Secure: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "http protocol",
|
||||
yaml: `
|
||||
name: http-clickhouse
|
||||
kind: clickhouse
|
||||
host: clickhouse.example.com
|
||||
port: "8123"
|
||||
user: analytics
|
||||
password: "securepass"
|
||||
database: analytics_db
|
||||
protocol: http
|
||||
secure: false
|
||||
`,
|
||||
expected: Config{
|
||||
Name: "http-clickhouse",
|
||||
Kind: "clickhouse",
|
||||
Host: "clickhouse.example.com",
|
||||
Port: "8123",
|
||||
User: "analytics",
|
||||
Password: "securepass",
|
||||
Database: "analytics_db",
|
||||
Protocol: "http",
|
||||
Secure: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "https with secure connection",
|
||||
yaml: `
|
||||
name: secure-clickhouse
|
||||
kind: clickhouse
|
||||
host: secure.clickhouse.io
|
||||
port: "8443"
|
||||
user: secureuser
|
||||
password: "verysecure"
|
||||
database: production
|
||||
protocol: https
|
||||
secure: true
|
||||
`,
|
||||
expected: Config{
|
||||
Name: "secure-clickhouse",
|
||||
Kind: "clickhouse",
|
||||
Host: "secure.clickhouse.io",
|
||||
Port: "8443",
|
||||
User: "secureuser",
|
||||
Password: "verysecure",
|
||||
Database: "production",
|
||||
Protocol: "https",
|
||||
Secure: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
decoder := yaml.NewDecoder(strings.NewReader(string(testutils.FormatYaml(tt.yaml))))
|
||||
config, err := newConfig(context.Background(), tt.expected.Name, decoder)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create config: %v", err)
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
|
||||
clickhouseConfig, ok := config.(Config)
|
||||
if !ok {
|
||||
t.Fatalf("Expected Config type, got %T", config)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expected, clickhouseConfig); diff != "" {
|
||||
t.Errorf("Config mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -143,11 +167,19 @@ func TestNewConfigInvalidYAML(t *testing.T) {
|
||||
name: "invalid yaml syntax",
|
||||
yaml: `
|
||||
name: test-clickhouse
|
||||
type: clickhouse
|
||||
kind: clickhouse
|
||||
host: [invalid
|
||||
`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "missing required fields",
|
||||
yaml: `
|
||||
name: test-clickhouse
|
||||
kind: clickhouse
|
||||
`,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -164,10 +196,10 @@ func TestNewConfigInvalidYAML(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSource_SourceType(t *testing.T) {
|
||||
func TestSource_SourceKind(t *testing.T) {
|
||||
source := &Source{}
|
||||
if source.SourceType() != SourceType {
|
||||
t.Errorf("Expected %s, got %s", SourceType, source.SourceType())
|
||||
if source.SourceKind() != SourceKind {
|
||||
t.Errorf("Expected %s, got %s", SourceKind, source.SourceKind())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -29,15 +29,15 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const SourceType string = "cloud-gemini-data-analytics"
|
||||
const SourceKind string = "cloud-gemini-data-analytics"
|
||||
const Endpoint string = "https://geminidataanalytics.googleapis.com"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,13 +51,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
ProjectID string `yaml:"projectId" validate:"required"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
// Initialize initializes a Gemini Data Analytics Source instance.
|
||||
@@ -102,8 +102,8 @@ type Source struct {
|
||||
userAgent string
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -38,15 +39,15 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-gda-instance
|
||||
type: cloud-gemini-data-analytics
|
||||
projectId: test-project-id
|
||||
`,
|
||||
sources:
|
||||
my-gda-instance:
|
||||
kind: cloud-gemini-data-analytics
|
||||
projectId: test-project-id
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-gda-instance": cloudgda.Config{
|
||||
Name: "my-gda-instance",
|
||||
Type: cloudgda.SourceType,
|
||||
Kind: cloudgda.SourceKind,
|
||||
ProjectID: "test-project-id",
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
@@ -55,16 +56,16 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-gda-instance
|
||||
type: cloud-gemini-data-analytics
|
||||
projectId: another-project
|
||||
useClientOAuth: true
|
||||
sources:
|
||||
my-gda-instance:
|
||||
kind: cloud-gemini-data-analytics
|
||||
projectId: another-project
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-gda-instance": cloudgda.Config{
|
||||
Name: "my-gda-instance",
|
||||
Type: cloudgda.SourceType,
|
||||
Kind: cloudgda.SourceKind,
|
||||
ProjectID: "another-project",
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
@@ -75,12 +76,16 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -96,18 +101,22 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "missing projectId",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-gda-instance
|
||||
type: cloud-gemini-data-analytics
|
||||
sources:
|
||||
my-gda-instance:
|
||||
kind: cloud-gemini-data-analytics
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-gda-instance\" as \"cloud-gemini-data-analytics\": Key: 'Config.ProjectID' Error:Field validation for 'ProjectID' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-gda-instance\" as \"cloud-gemini-data-analytics\": Key: 'Config.ProjectID' Error:Field validation for 'ProjectID' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
@@ -144,12 +153,12 @@ func TestInitialize(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
desc: "initialize with ADC",
|
||||
cfg: cloudgda.Config{Name: "test-gda", Type: cloudgda.SourceType, ProjectID: "test-proj"},
|
||||
cfg: cloudgda.Config{Name: "test-gda", Kind: cloudgda.SourceKind, ProjectID: "test-proj"},
|
||||
wantClientOAuth: false,
|
||||
},
|
||||
{
|
||||
desc: "initialize with client OAuth",
|
||||
cfg: cloudgda.Config{Name: "test-gda-oauth", Type: cloudgda.SourceType, ProjectID: "test-proj", UseClientOAuth: true},
|
||||
cfg: cloudgda.Config{Name: "test-gda-oauth", Kind: cloudgda.SourceKind, ProjectID: "test-proj", UseClientOAuth: true},
|
||||
wantClientOAuth: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ import (
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const SourceType string = "cloud-healthcare"
|
||||
const SourceKind string = "cloud-healthcare"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
@@ -42,8 +42,8 @@ var _ sources.SourceConfig = Config{}
|
||||
type HealthcareServiceCreator func(tokenString string) (*healthcare.Service, error)
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
// Healthcare configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Region string `yaml:"region" validate:"required"`
|
||||
Dataset string `yaml:"dataset" validate:"required"`
|
||||
@@ -67,8 +67,8 @@ type Config struct {
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (c Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (c Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -144,7 +144,7 @@ func newHealthcareServiceCreator(ctx context.Context, tracer trace.Tracer, name
|
||||
}
|
||||
|
||||
func initHealthcareConnectionWithOAuthToken(ctx context.Context, tracer trace.Tracer, name string, userAgent string, tokenString string) (*healthcare.Service, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
// Construct token source
|
||||
token := &oauth2.Token{
|
||||
@@ -162,7 +162,7 @@ func initHealthcareConnectionWithOAuthToken(ctx context.Context, tracer trace.Tr
|
||||
}
|
||||
|
||||
func initHealthcareConnection(ctx context.Context, tracer trace.Tracer, name string) (*healthcare.Service, oauth2.TokenSource, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
cred, err := google.FindDefaultCredentials(ctx, healthcare.CloudHealthcareScope)
|
||||
@@ -194,8 +194,8 @@ type Source struct {
|
||||
allowedDICOMStores map[string]struct{}
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -517,14 +517,14 @@ func (s *Source) RetrieveRenderedDICOMInstance(storeID, study, series, sop strin
|
||||
return base64String, nil
|
||||
}
|
||||
|
||||
func (s *Source) SearchDICOM(toolType, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
|
||||
func (s *Source) SearchDICOM(toolKind, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
var resp *http.Response
|
||||
switch toolType {
|
||||
switch toolKind {
|
||||
case "cloud-healthcare-search-dicom-instances":
|
||||
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
|
||||
case "cloud-healthcare-search-dicom-series":
|
||||
@@ -532,7 +532,7 @@ func (s *Source) SearchDICOM(toolType, storeID, dicomWebPath, tokenStr string, o
|
||||
case "cloud-healthcare-search-dicom-studies":
|
||||
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, dicomWebPath).Do(opts...)
|
||||
default:
|
||||
return nil, fmt.Errorf("incompatible tool type: %s", toolType)
|
||||
return nil, fmt.Errorf("incompatible tool kind: %s", toolKind)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom series: %w", err)
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package cloudhealthcare_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,17 +33,17 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: cloud-healthcare
|
||||
project: my-project
|
||||
region: us-central1
|
||||
dataset: my-dataset
|
||||
sources:
|
||||
my-instance:
|
||||
kind: cloud-healthcare
|
||||
project: my-project
|
||||
region: us-central1
|
||||
dataset: my-dataset
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudhealthcare.Config{
|
||||
Name: "my-instance",
|
||||
Type: cloudhealthcare.SourceType,
|
||||
Kind: cloudhealthcare.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "us-central1",
|
||||
Dataset: "my-dataset",
|
||||
@@ -55,18 +54,18 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: cloud-healthcare
|
||||
project: my-project
|
||||
region: us
|
||||
dataset: my-dataset
|
||||
useClientOAuth: true
|
||||
sources:
|
||||
my-instance:
|
||||
kind: cloud-healthcare
|
||||
project: my-project
|
||||
region: us
|
||||
dataset: my-dataset
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudhealthcare.Config{
|
||||
Name: "my-instance",
|
||||
Type: cloudhealthcare.SourceType,
|
||||
Kind: cloudhealthcare.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "us",
|
||||
Dataset: "my-dataset",
|
||||
@@ -77,22 +76,22 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
|
||||
{
|
||||
desc: "with allowed stores example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: cloud-healthcare
|
||||
project: my-project
|
||||
region: us
|
||||
dataset: my-dataset
|
||||
allowedFhirStores:
|
||||
- my-fhir-store
|
||||
allowedDicomStores:
|
||||
- my-dicom-store1
|
||||
- my-dicom-store2
|
||||
sources:
|
||||
my-instance:
|
||||
kind: cloud-healthcare
|
||||
project: my-project
|
||||
region: us
|
||||
dataset: my-dataset
|
||||
allowedFhirStores:
|
||||
- my-fhir-store
|
||||
allowedDicomStores:
|
||||
- my-dicom-store1
|
||||
- my-dicom-store2
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudhealthcare.Config{
|
||||
Name: "my-instance",
|
||||
Type: cloudhealthcare.SourceType,
|
||||
Kind: cloudhealthcare.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "us",
|
||||
Dataset: "my-dataset",
|
||||
@@ -104,12 +103,16 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -124,31 +127,35 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: cloud-healthcare
|
||||
project: my-project
|
||||
region: us-central1
|
||||
dataset: my-dataset
|
||||
foo: bar
|
||||
sources:
|
||||
my-instance:
|
||||
kind: cloud-healthcare
|
||||
project: my-project
|
||||
region: us-central1
|
||||
dataset: my-dataset
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-healthcare\": [2:1] unknown field \"foo\"\n 1 | dataset: my-dataset\n> 2 | foo: bar\n ^\n 3 | name: my-instance\n 4 | project: my-project\n 5 | region: us-central1\n 6 | ",
|
||||
err: "unable to parse source \"my-instance\" as \"cloud-healthcare\": [2:1] unknown field \"foo\"\n 1 | dataset: my-dataset\n> 2 | foo: bar\n ^\n 3 | kind: cloud-healthcare\n 4 | project: my-project\n 5 | region: us-central1",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: cloud-healthcare
|
||||
project: my-project
|
||||
region: us-central1
|
||||
sources:
|
||||
my-instance:
|
||||
kind: cloud-healthcare
|
||||
project: my-project
|
||||
region: us-central1
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-healthcare\": Key: 'Config.Dataset' Error:Field validation for 'Dataset' failed on the 'required' tag",
|
||||
err: `unable to parse source "my-instance" as "cloud-healthcare": Key: 'Config.Dataset' Error:Field validation for 'Dataset' failed on the 'required' tag`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -29,14 +29,14 @@ import (
|
||||
monitoring "google.golang.org/api/monitoring/v3"
|
||||
)
|
||||
|
||||
const SourceType string = "cloud-monitoring"
|
||||
const SourceKind string = "cloud-monitoring"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,12 +50,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
// Initialize initializes a Cloud Monitoring Source instance.
|
||||
@@ -99,8 +99,8 @@ type Source struct {
|
||||
userAgent string
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
package cloudmonitoring_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -35,14 +35,14 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-cloud-monitoring-instance
|
||||
type: cloud-monitoring
|
||||
sources:
|
||||
my-cloud-monitoring-instance:
|
||||
kind: cloud-monitoring
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-cloud-monitoring-instance": cloudmonitoring.Config{
|
||||
Name: "my-cloud-monitoring-instance",
|
||||
Type: cloudmonitoring.SourceType,
|
||||
Kind: cloudmonitoring.SourceKind,
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
@@ -50,15 +50,15 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-cloud-monitoring-instance
|
||||
type: cloud-monitoring
|
||||
useClientOAuth: true
|
||||
sources:
|
||||
my-cloud-monitoring-instance:
|
||||
kind: cloud-monitoring
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-cloud-monitoring-instance": cloudmonitoring.Config{
|
||||
Name: "my-cloud-monitoring-instance",
|
||||
Type: cloudmonitoring.SourceType,
|
||||
Kind: cloudmonitoring.SourceKind,
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
@@ -68,12 +68,16 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -89,28 +93,36 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-cloud-monitoring-instance
|
||||
type: cloud-monitoring
|
||||
project: test-project
|
||||
sources:
|
||||
my-cloud-monitoring-instance:
|
||||
kind: cloud-monitoring
|
||||
project: test-project
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-cloud-monitoring-instance\" as \"cloud-monitoring\": [2:1] unknown field \"project\"\n 1 | name: my-cloud-monitoring-instance\n> 2 | project: test-project\n ^\n 3 | type: cloud-monitoring",
|
||||
err: `unable to parse source "my-cloud-monitoring-instance" as "cloud-monitoring": [2:1] unknown field "project"
|
||||
1 | kind: cloud-monitoring
|
||||
> 2 | project: test-project
|
||||
^
|
||||
`,
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-cloud-monitoring-instance
|
||||
useClientOAuth: true
|
||||
sources:
|
||||
my-cloud-monitoring-instance:
|
||||
useClientOAuth: true
|
||||
`,
|
||||
err: "error unmarshaling sources: missing 'type' field or it is not a string",
|
||||
err: "missing 'kind' field for source \"my-cloud-monitoring-instance\"",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ import (
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const SourceType string = "cloud-sql-admin"
|
||||
const SourceKind string = "cloud-sql-admin"
|
||||
|
||||
var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
|
||||
|
||||
@@ -42,8 +42,8 @@ var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/da
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,13 +57,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
DefaultProject string `yaml:"defaultProject"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
// Initialize initializes a CloudSQL Admin Source instance.
|
||||
@@ -110,8 +110,8 @@ type Source struct {
|
||||
Service *sqladmin.Service
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -352,28 +352,6 @@ func (s *Source) GetWaitForOperations(ctx context.Context, service *sqladmin.Ser
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *Source) InsertBackupRun(ctx context.Context, project, instance, location, backupDescription, accessToken string) (any, error) {
|
||||
backupRun := &sqladmin.BackupRun{}
|
||||
if location != "" {
|
||||
backupRun.Location = location
|
||||
}
|
||||
if backupDescription != "" {
|
||||
backupRun.Description = backupDescription
|
||||
}
|
||||
|
||||
service, err := s.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := service.BackupRuns.Insert(project, instance, backupRun).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating backup: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func generateCloudSQLConnectionMessage(ctx context.Context, source *Source, logger log.Logger, opResponse map[string]any, connectionMessageTemplate string) (string, bool) {
|
||||
operationType, ok := opResponse["operationType"].(string)
|
||||
if !ok || operationType != "CREATE_DATABASE" {
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
package cloudsqladmin_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -35,14 +35,14 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-cloud-sql-admin-instance
|
||||
type: cloud-sql-admin
|
||||
sources:
|
||||
my-cloud-sql-admin-instance:
|
||||
kind: cloud-sql-admin
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-cloud-sql-admin-instance": cloudsqladmin.Config{
|
||||
Name: "my-cloud-sql-admin-instance",
|
||||
Type: cloudsqladmin.SourceType,
|
||||
Kind: cloudsqladmin.SourceKind,
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
@@ -50,15 +50,15 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-cloud-sql-admin-instance
|
||||
type: cloud-sql-admin
|
||||
useClientOAuth: true
|
||||
sources:
|
||||
my-cloud-sql-admin-instance:
|
||||
kind: cloud-sql-admin
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-cloud-sql-admin-instance": cloudsqladmin.Config{
|
||||
Name: "my-cloud-sql-admin-instance",
|
||||
Type: cloudsqladmin.SourceType,
|
||||
Kind: cloudsqladmin.SourceKind,
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
@@ -68,12 +68,16 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -89,28 +93,36 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-cloud-sql-admin-instance
|
||||
type: cloud-sql-admin
|
||||
project: test-project
|
||||
sources:
|
||||
my-cloud-sql-admin-instance:
|
||||
kind: cloud-sql-admin
|
||||
project: test-project
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-cloud-sql-admin-instance\" as \"cloud-sql-admin\": [2:1] unknown field \"project\"\n 1 | name: my-cloud-sql-admin-instance\n> 2 | project: test-project\n ^\n 3 | type: cloud-sql-admin",
|
||||
err: `unable to parse source "my-cloud-sql-admin-instance" as "cloud-sql-admin": [2:1] unknown field "project"
|
||||
1 | kind: cloud-sql-admin
|
||||
> 2 | project: test-project
|
||||
^
|
||||
`,
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-cloud-sql-admin-instance
|
||||
useClientOAuth: true
|
||||
sources:
|
||||
my-cloud-sql-admin-instance:
|
||||
useClientOAuth: true
|
||||
`,
|
||||
err: "error unmarshaling sources: missing 'type' field or it is not a string",
|
||||
err: "missing 'kind' field for source \"my-cloud-sql-admin-instance\"",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -29,14 +29,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "cloud-sql-mssql"
|
||||
const SourceKind string = "cloud-sql-mssql"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
// Cloud SQL MSSQL configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Region string `yaml:"region" validate:"required"`
|
||||
Instance string `yaml:"instance" validate:"required"`
|
||||
@@ -62,9 +62,9 @@ type Config struct {
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
// Returns Cloud SQL MSSQL source type
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
// Returns Cloud SQL MSSQL source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -94,9 +94,9 @@ type Source struct {
|
||||
Db *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
// Returns Cloud SQL MSSQL source type
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
// Returns Cloud SQL MSSQL source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -152,7 +152,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package cloudsqlmssql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,20 +33,20 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: cloud-sql-mssql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-instance:
|
||||
kind: cloud-sql-mssql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudsqlmssql.Config{
|
||||
Name: "my-instance",
|
||||
Type: cloudsqlmssql.SourceType,
|
||||
Kind: cloudsqlmssql.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -61,21 +60,21 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
|
||||
{
|
||||
desc: "psc ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: cloud-sql-mssql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
ipType: psc
|
||||
sources:
|
||||
my-instance:
|
||||
kind: cloud-sql-mssql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
ipType: psc
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudsqlmssql.Config{
|
||||
Name: "my-instance",
|
||||
Type: cloudsqlmssql.SourceType,
|
||||
Kind: cloudsqlmssql.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -89,21 +88,21 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
|
||||
{
|
||||
desc: "with deprecated ipAddress",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: cloud-sql-mssql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipAddress: random
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-instance:
|
||||
kind: cloud-sql-mssql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipAddress: random
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudsqlmssql.Config{
|
||||
Name: "my-instance",
|
||||
Type: cloudsqlmssql.SourceType,
|
||||
Kind: cloudsqlmssql.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -118,12 +117,16 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -139,53 +142,57 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "invalid ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: cloud-sql-mssql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: fail
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-instance:
|
||||
kind: cloud-sql-mssql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: fail
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-sql-mssql\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
|
||||
err: "unable to parse source \"my-instance\" as \"cloud-sql-mssql\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
|
||||
},
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: cloud-sql-mssql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
sources:
|
||||
my-instance:
|
||||
kind: cloud-sql-mssql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-sql-mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | name: my-instance\n 5 | password: my_pass\n 6 | ",
|
||||
err: "unable to parse source \"my-instance\" as \"cloud-sql-mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | kind: cloud-sql-mssql\n 5 | password: my_pass\n 6 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: cloud-sql-mssql
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-instance:
|
||||
kind: cloud-sql-mssql
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-sql-mssql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-instance\" as \"cloud-sql-mssql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -30,14 +30,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "cloud-sql-mysql"
|
||||
const SourceKind string = "cloud-sql-mysql"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Region string `yaml:"region" validate:"required"`
|
||||
Instance string `yaml:"instance" validate:"required"`
|
||||
@@ -61,8 +61,8 @@ type Config struct {
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -90,8 +90,8 @@ type Source struct {
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -184,7 +184,7 @@ func getConnectionConfig(ctx context.Context, user, pass string) (string, string
|
||||
|
||||
func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package cloudsqlmysql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,20 +33,20 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": cloudsqlmysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Type: cloudsqlmysql.SourceType,
|
||||
Kind: cloudsqlmysql.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -61,21 +60,21 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
{
|
||||
desc: "public ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: Public
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: Public
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": cloudsqlmysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Type: cloudsqlmysql.SourceType,
|
||||
Kind: cloudsqlmysql.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -89,21 +88,21 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
{
|
||||
desc: "private ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: private
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: private
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": cloudsqlmysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Type: cloudsqlmysql.SourceType,
|
||||
Kind: cloudsqlmysql.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -117,21 +116,21 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
{
|
||||
desc: "psc ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: psc
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: psc
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": cloudsqlmysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Type: cloudsqlmysql.SourceType,
|
||||
Kind: cloudsqlmysql.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -145,12 +144,16 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -166,53 +169,57 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "invalid ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: fail
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: fail
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
|
||||
err: "unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
|
||||
},
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | name: my-mysql-instance\n 5 | password: my_pass\n 6 | ",
|
||||
err: "unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | kind: cloud-sql-mysql\n 5 | password: my_pass\n 6 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: cloud-sql-mysql
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: cloud-sql-mysql
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -28,14 +28,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "cloud-sql-postgres"
|
||||
const SourceKind string = "cloud-sql-postgres"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Region string `yaml:"region" validate:"required"`
|
||||
Instance string `yaml:"instance" validate:"required"`
|
||||
@@ -59,8 +59,8 @@ type Config struct {
|
||||
Password string `yaml:"password"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -88,8 +88,8 @@ type Source struct {
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -162,7 +162,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string
|
||||
|
||||
func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package cloudsqlpg_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,20 +33,20 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: cloudsqlpg.SourceType,
|
||||
Kind: cloudsqlpg.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -61,21 +60,21 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
{
|
||||
desc: "public ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: Public
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: Public
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: cloudsqlpg.SourceType,
|
||||
Kind: cloudsqlpg.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -89,21 +88,21 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
{
|
||||
desc: "private ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: private
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: private
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: cloudsqlpg.SourceType,
|
||||
Kind: cloudsqlpg.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -117,21 +116,21 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
{
|
||||
desc: "psc ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: psc
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: psc
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: cloudsqlpg.SourceType,
|
||||
Kind: cloudsqlpg.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -145,12 +144,16 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -166,53 +169,57 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "invalid ipType",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: fail
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: fail
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
|
||||
err: "unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
|
||||
},
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | name: my-pg-instance\n 5 | password: my_pass\n 6 | ",
|
||||
err: "unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | kind: cloud-sql-postgres\n 5 | password: my_pass\n 6 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -29,14 +29,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "couchbase"
|
||||
const SourceKind string = "couchbase"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
ConnectionString string `yaml:"connectionString" validate:"required"`
|
||||
Bucket string `yaml:"bucket" validate:"required"`
|
||||
Scope string `yaml:"scope" validate:"required"`
|
||||
@@ -66,8 +66,8 @@ type Config struct {
|
||||
QueryScanConsistency uint `yaml:"queryScanConsistency"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -96,8 +96,8 @@ type Source struct {
|
||||
Scope *gocb.Scope
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package couchbase_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/couchbase"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,19 +33,19 @@ func TestParseFromYamlCouchbase(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-couchbase-instance
|
||||
type: couchbase
|
||||
connectionString: localhost
|
||||
username: Administrator
|
||||
password: password
|
||||
bucket: travel-sample
|
||||
scope: inventory
|
||||
sources:
|
||||
my-couchbase-instance:
|
||||
kind: couchbase
|
||||
connectionString: localhost
|
||||
username: Administrator
|
||||
password: password
|
||||
bucket: travel-sample
|
||||
scope: inventory
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-couchbase-instance": couchbase.Config{
|
||||
Name: "my-couchbase-instance",
|
||||
Type: couchbase.SourceType,
|
||||
Kind: couchbase.SourceKind,
|
||||
ConnectionString: "localhost",
|
||||
Username: "Administrator",
|
||||
Password: "password",
|
||||
@@ -58,24 +57,24 @@ func TestParseFromYamlCouchbase(t *testing.T) {
|
||||
{
|
||||
desc: "with TLS configuration",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-couchbase-instance
|
||||
type: couchbase
|
||||
connectionString: couchbases://localhost
|
||||
bucket: travel-sample
|
||||
scope: inventory
|
||||
clientCert: /path/to/cert.pem
|
||||
clientKey: /path/to/key.pem
|
||||
clientCertPassword: password
|
||||
clientKeyPassword: password
|
||||
caCert: /path/to/ca.pem
|
||||
noSslVerify: false
|
||||
queryScanConsistency: 2
|
||||
sources:
|
||||
my-couchbase-instance:
|
||||
kind: couchbase
|
||||
connectionString: couchbases://localhost
|
||||
bucket: travel-sample
|
||||
scope: inventory
|
||||
clientCert: /path/to/cert.pem
|
||||
clientKey: /path/to/key.pem
|
||||
clientCertPassword: password
|
||||
clientKeyPassword: password
|
||||
caCert: /path/to/ca.pem
|
||||
noSslVerify: false
|
||||
queryScanConsistency: 2
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-couchbase-instance": couchbase.Config{
|
||||
Name: "my-couchbase-instance",
|
||||
Type: couchbase.SourceType,
|
||||
Kind: couchbase.SourceKind,
|
||||
ConnectionString: "couchbases://localhost",
|
||||
Bucket: "travel-sample",
|
||||
Scope: "inventory",
|
||||
@@ -92,12 +91,16 @@ func TestParseFromYamlCouchbase(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -112,35 +115,39 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-couchbase-instance
|
||||
type: couchbase
|
||||
connectionString: localhost
|
||||
username: Administrator
|
||||
password: password
|
||||
bucket: travel-sample
|
||||
scope: inventory
|
||||
foo: bar
|
||||
sources:
|
||||
my-couchbase-instance:
|
||||
kind: couchbase
|
||||
connectionString: localhost
|
||||
username: Administrator
|
||||
password: password
|
||||
bucket: travel-sample
|
||||
scope: inventory
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-couchbase-instance\" as \"couchbase\": [3:1] unknown field \"foo\"\n 1 | bucket: travel-sample\n 2 | connectionString: localhost\n> 3 | foo: bar\n ^\n 4 | name: my-couchbase-instance\n 5 | password: password\n 6 | scope: inventory\n 7 | ",
|
||||
err: "unable to parse source \"my-couchbase-instance\" as \"couchbase\": [3:1] unknown field \"foo\"\n 1 | bucket: travel-sample\n 2 | connectionString: localhost\n> 3 | foo: bar\n ^\n 4 | kind: couchbase\n 5 | password: password\n 6 | scope: inventory\n 7 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-couchbase-instance
|
||||
type: couchbase
|
||||
username: Administrator
|
||||
password: password
|
||||
bucket: travel-sample
|
||||
scope: inventory
|
||||
sources:
|
||||
my-couchbase-instance:
|
||||
kind: couchbase
|
||||
username: Administrator
|
||||
password: password
|
||||
bucket: travel-sample
|
||||
scope: inventory
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-couchbase-instance\" as \"couchbase\": Key: 'Config.ConnectionString' Error:Field validation for 'ConnectionString' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-couchbase-instance\" as \"couchbase\": Key: 'Config.ConnectionString' Error:Field validation for 'ConnectionString' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -29,14 +29,14 @@ import (
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const SourceType string = "dataplex"
|
||||
const SourceKind string = "dataplex"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,13 +51,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
// Dataplex configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
// Returns Dataplex source type
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
// Returns Dataplex source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -81,9 +81,9 @@ type Source struct {
|
||||
Client *dataplexapi.CatalogClient
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
// Returns Dataplex source type
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
// Returns Dataplex source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -104,7 +104,7 @@ func initDataplexConnection(
|
||||
name string,
|
||||
project string,
|
||||
) (*dataplexapi.CatalogClient, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
cred, err := google.FindDefaultCredentials(ctx)
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package dataplex_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/dataplex"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,15 +33,15 @@ func TestParseFromYamlDataplex(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: dataplex
|
||||
project: my-project
|
||||
sources:
|
||||
my-instance:
|
||||
kind: dataplex
|
||||
project: my-project
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": dataplex.Config{
|
||||
Name: "my-instance",
|
||||
Type: dataplex.SourceType,
|
||||
Kind: dataplex.SourceKind,
|
||||
Project: "my-project",
|
||||
},
|
||||
},
|
||||
@@ -50,12 +49,16 @@ func TestParseFromYamlDataplex(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -71,27 +74,31 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: dataplex
|
||||
project: my-project
|
||||
foo: bar
|
||||
sources:
|
||||
my-instance:
|
||||
kind: dataplex
|
||||
project: my-project
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"dataplex\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | name: my-instance\n 3 | project: my-project\n 4 | type: dataplex",
|
||||
err: "unable to parse source \"my-instance\" as \"dataplex\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: dataplex\n 3 | project: my-project",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: dataplex
|
||||
sources:
|
||||
my-instance:
|
||||
kind: dataplex
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"dataplex\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-instance\" as \"dataplex\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -30,14 +30,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "dgraph"
|
||||
const SourceKind string = "dgraph"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ type DgraphClient struct {
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
DgraphUrl string `yaml:"dgraphUrl" validate:"required"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
@@ -75,8 +75,8 @@ type Config struct {
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -103,8 +103,8 @@ type Source struct {
|
||||
Client *DgraphClient `yaml:"client"`
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -139,7 +139,7 @@ func (s *Source) RunSQL(statement string, params parameters.ParamValues, isQuery
|
||||
|
||||
func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (*DgraphClient, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, r.Name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, r.Name)
|
||||
defer span.End()
|
||||
|
||||
if r.DgraphUrl == "" {
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package dgraph_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/dgraph"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,19 +33,19 @@ func TestParseFromYamlDgraph(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-dgraph-instance
|
||||
type: dgraph
|
||||
dgraphUrl: https://localhost:8080
|
||||
apiKey: abc123
|
||||
password: pass@123
|
||||
namespace: 0
|
||||
user: user123
|
||||
sources:
|
||||
my-dgraph-instance:
|
||||
kind: dgraph
|
||||
dgraphUrl: https://localhost:8080
|
||||
apiKey: abc123
|
||||
password: pass@123
|
||||
namespace: 0
|
||||
user: user123
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-dgraph-instance": dgraph.Config{
|
||||
Name: "my-dgraph-instance",
|
||||
Type: dgraph.SourceType,
|
||||
Kind: dgraph.SourceKind,
|
||||
DgraphUrl: "https://localhost:8080",
|
||||
ApiKey: "abc123",
|
||||
Password: "pass@123",
|
||||
@@ -58,15 +57,15 @@ func TestParseFromYamlDgraph(t *testing.T) {
|
||||
{
|
||||
desc: "basic example minimal field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-dgraph-instance
|
||||
type: dgraph
|
||||
dgraphUrl: https://localhost:8080
|
||||
sources:
|
||||
my-dgraph-instance:
|
||||
kind: dgraph
|
||||
dgraphUrl: https://localhost:8080
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-dgraph-instance": dgraph.Config{
|
||||
Name: "my-dgraph-instance",
|
||||
Type: dgraph.SourceType,
|
||||
Kind: dgraph.SourceKind,
|
||||
DgraphUrl: "https://localhost:8080",
|
||||
},
|
||||
},
|
||||
@@ -75,12 +74,16 @@ func TestParseFromYamlDgraph(t *testing.T) {
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.want, got); diff != "" {
|
||||
if diff := cmp.Diff(tc.want, got.Sources); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
@@ -97,27 +100,31 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-dgraph-instance
|
||||
type: dgraph
|
||||
dgraphUrl: https://localhost:8080
|
||||
foo: bar
|
||||
sources:
|
||||
my-dgraph-instance:
|
||||
kind: dgraph
|
||||
dgraphUrl: https://localhost:8080
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-dgraph-instance\" as \"dgraph\": [2:1] unknown field \"foo\"\n 1 | dgraphUrl: https://localhost:8080\n> 2 | foo: bar\n ^\n 3 | name: my-dgraph-instance\n 4 | type: dgraph",
|
||||
err: "unable to parse source \"my-dgraph-instance\" as \"dgraph\": [2:1] unknown field \"foo\"\n 1 | dgraphUrl: https://localhost:8080\n> 2 | foo: bar\n ^\n 3 | kind: dgraph",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-dgraph-instance
|
||||
type: dgraph
|
||||
sources:
|
||||
my-dgraph-instance:
|
||||
kind: dgraph
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-dgraph-instance\" as \"dgraph\": Key: 'Config.DgraphUrl' Error:Field validation for 'DgraphUrl' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-dgraph-instance\" as \"dgraph\": Key: 'Config.DgraphUrl' Error:Field validation for 'DgraphUrl' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -30,14 +30,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "elasticsearch"
|
||||
const SourceKind string = "elasticsearch"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,15 +51,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Addresses []string `yaml:"addresses" validate:"required"`
|
||||
Username string `yaml:"username"`
|
||||
Password string `yaml:"password"`
|
||||
APIKey string `yaml:"apikey"`
|
||||
}
|
||||
|
||||
func (c Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (c Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
type EsClient interface {
|
||||
@@ -139,9 +139,9 @@ func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// SourceType returns the resourceType string for this source.
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
// SourceKind returns the kind string for this source.
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -15,15 +15,13 @@
|
||||
package elasticsearch_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/elasticsearch"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlElasticsearch(t *testing.T) {
|
||||
@@ -35,17 +33,17 @@ func TestParseFromYamlElasticsearch(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-es-instance
|
||||
type: elasticsearch
|
||||
addresses:
|
||||
- http://localhost:9200
|
||||
apikey: somekey
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
sources:
|
||||
my-es-instance:
|
||||
kind: elasticsearch
|
||||
addresses:
|
||||
- http://localhost:9200
|
||||
apikey: somekey
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-es-instance": elasticsearch.Config{
|
||||
Name: "my-es-instance",
|
||||
Type: elasticsearch.SourceType,
|
||||
Kind: elasticsearch.SourceKind,
|
||||
Addresses: []string{"http://localhost:9200"},
|
||||
APIKey: "somekey",
|
||||
},
|
||||
@@ -54,50 +52,20 @@ func TestParseFromYamlElasticsearch(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
err := yaml.Unmarshal([]byte(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse yaml: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got); diff != "" {
|
||||
if diff := cmp.Diff(tc.want, got.Sources); diff != "" {
|
||||
t.Errorf("unexpected config diff (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-es-instance
|
||||
type: elasticsearch
|
||||
addresses:
|
||||
- http://localhost:9200
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-es-instance\" as \"elasticsearch\": [3:1] unknown field \"foo\"\n 1 | addresses:\n 2 | - http://localhost:9200\n> 3 | foo: bar\n ^\n 4 | name: my-es-instance\n 5 | type: elasticsearch",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_esqlToMap(t1 *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -27,13 +27,13 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
)
|
||||
|
||||
const SourceType string = "firebird"
|
||||
const SourceKind string = "firebird"
|
||||
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -55,8 +55,8 @@ type Config struct {
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -84,8 +84,8 @@ type Source struct {
|
||||
Db *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -144,7 +144,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
}
|
||||
|
||||
func initFirebirdConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*sql.DB, error) {
|
||||
_, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// urlExample := "user:password@host:port/path/to/database.fdb"
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package firebird_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/firebird"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,19 +33,19 @@ func TestParseFromYamlFirebird(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-fdb-instance
|
||||
type: firebird
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-fdb-instance:
|
||||
kind: firebird
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-fdb-instance": firebird.Config{
|
||||
Name: "my-fdb-instance",
|
||||
Type: firebird.SourceType,
|
||||
Kind: firebird.SourceKind,
|
||||
Host: "my-host",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -58,12 +57,16 @@ func TestParseFromYamlFirebird(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -79,35 +82,39 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-fdb-instance
|
||||
type: firebird
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
sources:
|
||||
my-fdb-instance:
|
||||
kind: firebird
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-fdb-instance\" as \"firebird\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: my-host\n 4 | name: my-fdb-instance\n 5 | password: my_pass\n 6 | ",
|
||||
err: "unable to parse source \"my-fdb-instance\" as \"firebird\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: my-host\n 4 | kind: firebird\n 5 | password: my_pass\n 6 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-fdb-instance
|
||||
type: firebird
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
sources:
|
||||
my-fdb-instance:
|
||||
kind: firebird
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-fdb-instance\" as \"firebird\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-fdb-instance\" as \"firebird\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -31,14 +31,14 @@ import (
|
||||
"google.golang.org/genproto/googleapis/type/latlng"
|
||||
)
|
||||
|
||||
const SourceType string = "firestore"
|
||||
const SourceKind string = "firestore"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,14 +53,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
// Firestore configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Database string `yaml:"database"` // Optional, defaults to "(default)"
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
// Returns Firestore source type
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
// Returns Firestore source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -92,9 +92,9 @@ type Source struct {
|
||||
RulesClient *firebaserules.Service
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
// Returns Firestore source type
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
// Returns Firestore source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -594,7 +594,7 @@ func initFirestoreConnection(
|
||||
project string,
|
||||
database string,
|
||||
) (*firestore.Client, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
|
||||
@@ -15,13 +15,12 @@
|
||||
package firestore_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -35,15 +34,15 @@ func TestParseFromYamlFirestore(t *testing.T) {
|
||||
{
|
||||
desc: "basic example with default database",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-firestore
|
||||
type: firestore
|
||||
project: my-project
|
||||
sources:
|
||||
my-firestore:
|
||||
kind: firestore
|
||||
project: my-project
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-firestore": firestore.Config{
|
||||
Name: "my-firestore",
|
||||
Type: firestore.SourceType,
|
||||
Kind: firestore.SourceKind,
|
||||
Project: "my-project",
|
||||
Database: "",
|
||||
},
|
||||
@@ -52,16 +51,16 @@ func TestParseFromYamlFirestore(t *testing.T) {
|
||||
{
|
||||
desc: "with custom database",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-firestore
|
||||
type: firestore
|
||||
project: my-project
|
||||
database: my-database
|
||||
sources:
|
||||
my-firestore:
|
||||
kind: firestore
|
||||
project: my-project
|
||||
database: my-database
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-firestore": firestore.Config{
|
||||
Name: "my-firestore",
|
||||
Type: firestore.SourceType,
|
||||
Kind: firestore.SourceKind,
|
||||
Project: "my-project",
|
||||
Database: "my-database",
|
||||
},
|
||||
@@ -70,18 +69,22 @@ func TestParseFromYamlFirestore(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
func TestFailParseFromYamlFirestore(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
@@ -90,27 +93,32 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-firestore
|
||||
type: firestore
|
||||
project: my-project
|
||||
foo: bar
|
||||
sources:
|
||||
my-firestore:
|
||||
kind: firestore
|
||||
project: my-project
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-firestore\" as \"firestore\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | name: my-firestore\n 3 | project: my-project\n 4 | type: firestore",
|
||||
err: "unable to parse source \"my-firestore\" as \"firestore\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: firestore\n 3 | project: my-project",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-firestore
|
||||
type: firestore
|
||||
sources:
|
||||
my-firestore:
|
||||
kind: firestore
|
||||
database: my-database
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-firestore\" as \"firestore\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-firestore\" as \"firestore\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -29,14 +29,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "http"
|
||||
const SourceKind string = "http"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
BaseURL string `yaml:"baseUrl"`
|
||||
Timeout string `yaml:"timeout"`
|
||||
DefaultHeaders map[string]string `yaml:"headers"`
|
||||
@@ -58,8 +58,8 @@ type Config struct {
|
||||
DisableSslVerification bool `yaml:"disableSslVerification"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
// Initialize initializes an HTTP Source instance.
|
||||
@@ -122,8 +122,8 @@ type Source struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -34,15 +34,15 @@ func TestParseFromYamlHttp(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-http-instance
|
||||
type: http
|
||||
baseUrl: http://test_server/
|
||||
sources:
|
||||
my-http-instance:
|
||||
kind: http
|
||||
baseUrl: http://test_server/
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-http-instance": http.Config{
|
||||
Name: "my-http-instance",
|
||||
Type: http.SourceType,
|
||||
Kind: http.SourceKind,
|
||||
BaseURL: "http://test_server/",
|
||||
Timeout: "30s",
|
||||
DisableSslVerification: false,
|
||||
@@ -52,23 +52,23 @@ func TestParseFromYamlHttp(t *testing.T) {
|
||||
{
|
||||
desc: "advanced example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-http-instance
|
||||
type: http
|
||||
baseUrl: http://test_server/
|
||||
timeout: 10s
|
||||
headers:
|
||||
Authorization: test_header
|
||||
Custom-Header: custom
|
||||
queryParams:
|
||||
api-key: test_api_key
|
||||
param: param-value
|
||||
disableSslVerification: true
|
||||
sources:
|
||||
my-http-instance:
|
||||
kind: http
|
||||
baseUrl: http://test_server/
|
||||
timeout: 10s
|
||||
headers:
|
||||
Authorization: test_header
|
||||
Custom-Header: custom
|
||||
queryParams:
|
||||
api-key: test_api_key
|
||||
param: param-value
|
||||
disableSslVerification: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-http-instance": http.Config{
|
||||
Name: "my-http-instance",
|
||||
Type: http.SourceType,
|
||||
Kind: http.SourceKind,
|
||||
BaseURL: "http://test_server/",
|
||||
Timeout: "10s",
|
||||
DefaultHeaders: map[string]string{"Authorization": "test_header", "Custom-Header": "custom"},
|
||||
@@ -80,12 +80,16 @@ func TestParseFromYamlHttp(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -100,32 +104,36 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-http-instance
|
||||
type: http
|
||||
baseUrl: http://test_server/
|
||||
timeout: 10s
|
||||
headers:
|
||||
Authorization: test_header
|
||||
queryParams:
|
||||
api-key: test_api_key
|
||||
project: test-project
|
||||
sources:
|
||||
my-http-instance:
|
||||
kind: http
|
||||
baseUrl: http://test_server/
|
||||
timeout: 10s
|
||||
headers:
|
||||
Authorization: test_header
|
||||
queryParams:
|
||||
api-key: test_api_key
|
||||
project: test-project
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-http-instance\" as \"http\": [5:1] unknown field \"project\"\n 2 | headers:\n 3 | Authorization: test_header\n 4 | name: my-http-instance\n> 5 | project: test-project\n ^\n 6 | queryParams:\n 7 | api-key: test_api_key\n 8 | timeout: 10s\n 9 | ",
|
||||
err: "unable to parse source \"my-http-instance\" as \"http\": [5:1] unknown field \"project\"\n 2 | headers:\n 3 | Authorization: test_header\n 4 | kind: http\n> 5 | project: test-project\n ^\n 6 | queryParams:\n 7 | api-key: test_api_key\n 8 | timeout: 10s",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-http-instance
|
||||
baseUrl: http://test_server/
|
||||
sources:
|
||||
my-http-instance:
|
||||
baseUrl: http://test_server/
|
||||
`,
|
||||
err: "error unmarshaling sources: missing 'type' field or it is not a string",
|
||||
err: "missing 'kind' field for source \"my-http-instance\"",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -33,14 +33,14 @@ import (
|
||||
v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4"
|
||||
)
|
||||
|
||||
const SourceType string = "looker"
|
||||
const SourceKind string = "looker"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
BaseURL string `yaml:"base_url" validate:"required"`
|
||||
ClientId string `yaml:"client_id"`
|
||||
ClientSecret string `yaml:"client_secret"`
|
||||
@@ -79,8 +79,8 @@ type Config struct {
|
||||
SessionLength int64 `yaml:"sessionLength"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
// Initialize initializes a Looker Source instance.
|
||||
@@ -154,8 +154,8 @@ type Source struct {
|
||||
AuthTokenHeaderName string
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
package looker_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -34,17 +34,17 @@ func TestParseFromYamlLooker(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-looker-instance
|
||||
type: looker
|
||||
base_url: http://example.looker.com/
|
||||
client_id: jasdl;k;tjl
|
||||
client_secret: sdakl;jgflkasdfkfg
|
||||
sources:
|
||||
my-looker-instance:
|
||||
kind: looker
|
||||
base_url: http://example.looker.com/
|
||||
client_id: jasdl;k;tjl
|
||||
client_secret: sdakl;jgflkasdfkfg
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-looker-instance": looker.Config{
|
||||
Name: "my-looker-instance",
|
||||
Type: looker.SourceType,
|
||||
Kind: looker.SourceKind,
|
||||
BaseURL: "http://example.looker.com/",
|
||||
ClientId: "jasdl;k;tjl",
|
||||
ClientSecret: "sdakl;jgflkasdfkfg",
|
||||
@@ -62,18 +62,22 @@ func TestParseFromYamlLooker(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
func TestFailParseFromYamlLooker(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
@@ -82,30 +86,34 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-looker-instance
|
||||
type: looker
|
||||
base_url: http://example.looker.com/
|
||||
client_id: jasdl;k;tjl
|
||||
client_secret: sdakl;jgflkasdfkfg
|
||||
schema: test-schema
|
||||
sources:
|
||||
my-looker-instance:
|
||||
kind: looker
|
||||
base_url: http://example.looker.com/
|
||||
client_id: jasdl;k;tjl
|
||||
client_secret: sdakl;jgflkasdfkfg
|
||||
schema: test-schema
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-looker-instance\" as \"looker\": [5:1] unknown field \"schema\"\n 2 | client_id: jasdl;k;tjl\n 3 | client_secret: sdakl;jgflkasdfkfg\n 4 | name: my-looker-instance\n> 5 | schema: test-schema\n ^\n 6 | type: looker",
|
||||
err: "unable to parse source \"my-looker-instance\" as \"looker\": [5:1] unknown field \"schema\"\n 2 | client_id: jasdl;k;tjl\n 3 | client_secret: sdakl;jgflkasdfkfg\n 4 | kind: looker\n> 5 | schema: test-schema\n ^\n",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-looker-instance
|
||||
type: looker
|
||||
client_id: jasdl;k;tjl
|
||||
sources:
|
||||
my-looker-instance:
|
||||
kind: looker
|
||||
client_id: jasdl;k;tjl
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-looker-instance\" as \"looker\": Key: 'Config.BaseURL' Error:Field validation for 'BaseURL' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-looker-instance\" as \"looker\": Key: 'Config.BaseURL' Error:Field validation for 'BaseURL' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -27,14 +27,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "mindsdb"
|
||||
const SourceKind string = "mindsdb"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -57,8 +57,8 @@ type Config struct {
|
||||
QueryTimeout string `yaml:"queryTimeout"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -86,8 +86,8 @@ type Source struct {
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -159,7 +159,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initMindsDBConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package mindsdb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mindsdb"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,19 +33,19 @@ func TestParseFromYamlMindsDB(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mindsdb-instance
|
||||
type: mindsdb
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-mindsdb-instance:
|
||||
kind: mindsdb
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-mindsdb-instance": mindsdb.Config{
|
||||
Name: "my-mindsdb-instance",
|
||||
Type: mindsdb.SourceType,
|
||||
Kind: mindsdb.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -58,20 +57,20 @@ func TestParseFromYamlMindsDB(t *testing.T) {
|
||||
{
|
||||
desc: "with query timeout",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mindsdb-instance
|
||||
type: mindsdb
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryTimeout: 45s
|
||||
sources:
|
||||
my-mindsdb-instance:
|
||||
kind: mindsdb
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryTimeout: 45s
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-mindsdb-instance": mindsdb.Config{
|
||||
Name: "my-mindsdb-instance",
|
||||
Type: mindsdb.SourceType,
|
||||
Kind: mindsdb.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -84,12 +83,16 @@ func TestParseFromYamlMindsDB(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -105,35 +108,39 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mindsdb-instance
|
||||
type: mindsdb
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
sources:
|
||||
my-mindsdb-instance:
|
||||
kind: mindsdb
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-mindsdb-instance\" as \"mindsdb\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-mindsdb-instance\n 5 | password: my_pass\n 6 | ",
|
||||
err: "unable to parse source \"my-mindsdb-instance\" as \"mindsdb\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: mindsdb\n 5 | password: my_pass\n 6 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mindsdb-instance
|
||||
type: mindsdb
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-mindsdb-instance:
|
||||
kind: mindsdb
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-mindsdb-instance\" as \"mindsdb\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-mindsdb-instance\" as \"mindsdb\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -29,14 +29,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "mongodb"
|
||||
const SourceKind string = "mongodb"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,12 +50,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Uri string `yaml:"uri" validate:"required"` // MongoDB Atlas connection URI
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -84,8 +84,8 @@ type Source struct {
|
||||
Client *mongo.Client
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -293,7 +293,7 @@ func (s *Source) DeleteOne(ctx context.Context, filterString, database, collecti
|
||||
|
||||
func initMongoDBClient(ctx context.Context, tracer trace.Tracer, name, uri string) (*mongo.Client, error) {
|
||||
// Start a tracing span
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package mongodb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,15 +33,15 @@ func TestParseFromYamlMongoDB(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: mongo-db
|
||||
type: "mongodb"
|
||||
uri: "mongodb+srv://username:password@host/dbname"
|
||||
sources:
|
||||
mongo-db:
|
||||
kind: "mongodb"
|
||||
uri: "mongodb+srv://username:password@host/dbname"
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"mongo-db": mongodb.Config{
|
||||
Name: "mongo-db",
|
||||
Type: mongodb.SourceType,
|
||||
Kind: mongodb.SourceKind,
|
||||
Uri: "mongodb+srv://username:password@host/dbname",
|
||||
},
|
||||
},
|
||||
@@ -50,12 +49,16 @@ func TestParseFromYamlMongoDB(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -71,27 +74,31 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: mongo-db
|
||||
type: mongodb
|
||||
uri: "mongodb+srv://username:password@host/dbname"
|
||||
foo: bar
|
||||
sources:
|
||||
mongo-db:
|
||||
kind: mongodb
|
||||
uri: "mongodb+srv://username:password@host/dbname"
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"mongo-db\" as \"mongodb\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | name: mongo-db\n 3 | type: mongodb\n 4 | uri: mongodb+srv://username:password@host/dbname",
|
||||
err: "unable to parse source \"mongo-db\" as \"mongodb\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: mongodb\n 3 | uri: mongodb+srv://username:password@host/dbname",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: mongo-db
|
||||
type: mongodb
|
||||
sources:
|
||||
mongo-db:
|
||||
kind: mongodb
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"mongo-db\" as \"mongodb\": Key: 'Config.Uri' Error:Field validation for 'Uri' failed on the 'required' tag",
|
||||
err: "unable to parse source \"mongo-db\" as \"mongodb\": Key: 'Config.Uri' Error:Field validation for 'Uri' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -28,14 +28,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "mssql"
|
||||
const SourceKind string = "mssql"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
// Cloud SQL MSSQL configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -59,9 +59,9 @@ type Config struct {
|
||||
Encrypt string `yaml:"encrypt"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
// Returns Cloud SQL MSSQL source type
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
// Returns Cloud SQL MSSQL source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -91,9 +91,9 @@ type Source struct {
|
||||
Db *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
// Returns Cloud SQL MSSQL source type
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
// Returns Cloud SQL MSSQL source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -156,7 +156,7 @@ func initMssqlConnection(
|
||||
error,
|
||||
) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package mssql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,19 +33,19 @@ func TestParseFromYamlMssql(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mssql-instance
|
||||
type: mssql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-mssql-instance:
|
||||
kind: mssql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-mssql-instance": mssql.Config{
|
||||
Name: "my-mssql-instance",
|
||||
Type: mssql.SourceType,
|
||||
Kind: mssql.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -58,20 +57,20 @@ func TestParseFromYamlMssql(t *testing.T) {
|
||||
{
|
||||
desc: "with encrypt field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mssql-instance
|
||||
type: mssql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
encrypt: strict
|
||||
sources:
|
||||
my-mssql-instance:
|
||||
kind: mssql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
encrypt: strict
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-mssql-instance": mssql.Config{
|
||||
Name: "my-mssql-instance",
|
||||
Type: mssql.SourceType,
|
||||
Kind: mssql.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -84,12 +83,16 @@ func TestParseFromYamlMssql(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -104,35 +107,39 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mssql-instance
|
||||
type: mssql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
sources:
|
||||
my-mssql-instance:
|
||||
kind: mssql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-mssql-instance\" as \"mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-mssql-instance\n 5 | password: my_pass\n 6 | ",
|
||||
err: "unable to parse source \"my-mssql-instance\" as \"mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: mssql\n 5 | password: my_pass\n 6 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mssql-instance
|
||||
type: mssql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
sources:
|
||||
my-mssql-instance:
|
||||
kind: mssql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-mssql-instance\" as \"mssql\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-mssql-instance\" as \"mssql\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -30,14 +30,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "mysql"
|
||||
const SourceKind string = "mysql"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -61,8 +61,8 @@ type Config struct {
|
||||
QueryParams map[string]string `yaml:"queryParams"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -90,8 +90,8 @@ type Source struct {
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -158,7 +158,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string, queryParams map[string]string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Build query parameters via url.Values for deterministic order and proper escaping.
|
||||
|
||||
@@ -19,12 +19,12 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -38,19 +38,19 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: mysql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: mysql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": mysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Type: mysql.SourceType,
|
||||
Kind: mysql.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -62,20 +62,20 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
{
|
||||
desc: "with query timeout",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: mysql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryTimeout: 45s
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: mysql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryTimeout: 45s
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": mysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Type: mysql.SourceType,
|
||||
Kind: mysql.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -88,22 +88,22 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
{
|
||||
desc: "with query params",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: mysql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryParams:
|
||||
tls: preferred
|
||||
charset: utf8mb4
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: mysql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryParams:
|
||||
tls: preferred
|
||||
charset: utf8mb4
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": mysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Type: mysql.SourceType,
|
||||
Kind: mysql.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -120,11 +120,15 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got, cmpopts.EquateEmpty()); diff != "" {
|
||||
if diff := cmp.Diff(tc.want, got.Sources, cmpopts.EquateEmpty()); diff != "" {
|
||||
t.Fatalf("mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -141,51 +145,55 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: mysql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: mysql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-mysql-instance\n 5 | password: my_pass\n 6 | ",
|
||||
err: "unknown field \"foo\"",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: mysql
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: mysql
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"mysql\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
|
||||
err: "Field validation for 'Host' failed",
|
||||
},
|
||||
{
|
||||
desc: "invalid query params type",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-mysql-instance
|
||||
type: mysql
|
||||
host: 0.0.0.0
|
||||
port: 3306
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryParams: not-a-map
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: mysql
|
||||
host: 0.0.0.0
|
||||
port: 3306
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryParams: not-a-map
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"mysql\": [6:14] string was used where mapping is expected\n 3 | name: my-mysql-instance\n 4 | password: my_pass\n 5 | port: 3306\n> 6 | queryParams: not-a-map\n ^\n 7 | type: mysql\n 8 | user: my_user",
|
||||
err: "string was used where mapping is expected",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
@@ -203,7 +211,7 @@ func TestFailInitialization(t *testing.T) {
|
||||
|
||||
cfg := mysql.Config{
|
||||
Name: "instance",
|
||||
Type: "mysql",
|
||||
Kind: "mysql",
|
||||
Host: "localhost",
|
||||
Port: "3306",
|
||||
Database: "db",
|
||||
|
||||
@@ -29,7 +29,7 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "neo4j"
|
||||
const SourceKind string = "neo4j"
|
||||
|
||||
var sourceClassifier *classifier.QueryClassifier = classifier.NewQueryClassifier()
|
||||
|
||||
@@ -37,8 +37,8 @@ var sourceClassifier *classifier.QueryClassifier = classifier.NewQueryClassifier
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,15 +52,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Uri string `yaml:"uri" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -91,8 +91,8 @@ type Source struct {
|
||||
Driver neo4j.DriverWithContext
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -182,7 +182,7 @@ func addPlanChildren(p neo4j.Plan) []map[string]any {
|
||||
|
||||
func initNeo4jDriver(ctx context.Context, tracer trace.Tracer, uri, user, password, name string) (neo4j.DriverWithContext, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
auth := neo4j.BasicAuth(user, password, "")
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package neo4j_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,18 +33,18 @@ func TestParseFromYamlNeo4j(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-neo4j-instance
|
||||
type: neo4j
|
||||
uri: neo4j+s://my-host:7687
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-neo4j-instance:
|
||||
kind: neo4j
|
||||
uri: neo4j+s://my-host:7687
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-neo4j-instance": neo4j.Config{
|
||||
Name: "my-neo4j-instance",
|
||||
Type: neo4j.SourceType,
|
||||
Kind: neo4j.SourceKind,
|
||||
Uri: "neo4j+s://my-host:7687",
|
||||
Database: "my_db",
|
||||
User: "my_user",
|
||||
@@ -56,12 +55,16 @@ func TestParseFromYamlNeo4j(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -77,33 +80,37 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-neo4j-instance
|
||||
type: neo4j
|
||||
uri: neo4j+s://my-host:7687
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
sources:
|
||||
my-neo4j-instance:
|
||||
kind: neo4j
|
||||
uri: neo4j+s://my-host:7687
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-neo4j-instance\" as \"neo4j\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | name: my-neo4j-instance\n 4 | password: my_pass\n 5 | type: neo4j\n 6 | ",
|
||||
err: "unable to parse source \"my-neo4j-instance\" as \"neo4j\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | kind: neo4j\n 4 | password: my_pass\n 5 | uri: neo4j+s://my-host:7687\n 6 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-neo4j-instance
|
||||
type: neo4j
|
||||
uri: neo4j+s://my-host:7687
|
||||
database: my_db
|
||||
user: my_user
|
||||
sources:
|
||||
my-neo4j-instance:
|
||||
kind: neo4j
|
||||
uri: neo4j+s://my-host:7687
|
||||
database: my_db
|
||||
user: my_user
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-neo4j-instance\" as \"neo4j\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-neo4j-instance\" as \"neo4j\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -27,14 +27,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "oceanbase"
|
||||
const SourceKind string = "oceanbase"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -57,8 +57,8 @@ type Config struct {
|
||||
QueryTimeout string `yaml:"queryTimeout"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -86,8 +86,8 @@ type Source struct {
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -153,7 +153,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
}
|
||||
|
||||
func initOceanBaseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
||||
_, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, pass, host, port, dbname)
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package oceanbase_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/oceanbase"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -35,19 +34,19 @@ func TestParseFromYamlOceanBase(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-oceanbase-instance
|
||||
type: oceanbase
|
||||
host: 0.0.0.0
|
||||
port: 2881
|
||||
database: ob_db
|
||||
user: ob_user
|
||||
password: ob_pass
|
||||
sources:
|
||||
my-oceanbase-instance:
|
||||
kind: oceanbase
|
||||
host: 0.0.0.0
|
||||
port: 2881
|
||||
database: ob_db
|
||||
user: ob_user
|
||||
password: ob_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-oceanbase-instance": oceanbase.Config{
|
||||
Name: "my-oceanbase-instance",
|
||||
Type: oceanbase.SourceType,
|
||||
Kind: oceanbase.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "2881",
|
||||
Database: "ob_db",
|
||||
@@ -59,20 +58,20 @@ func TestParseFromYamlOceanBase(t *testing.T) {
|
||||
{
|
||||
desc: "with query timeout",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-oceanbase-instance
|
||||
type: oceanbase
|
||||
host: 0.0.0.0
|
||||
port: 2881
|
||||
database: ob_db
|
||||
user: ob_user
|
||||
password: ob_pass
|
||||
queryTimeout: 30s
|
||||
sources:
|
||||
my-oceanbase-instance:
|
||||
kind: oceanbase
|
||||
host: 0.0.0.0
|
||||
port: 2881
|
||||
database: ob_db
|
||||
user: ob_user
|
||||
password: ob_pass
|
||||
queryTimeout: 30s
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-oceanbase-instance": oceanbase.Config{
|
||||
Name: "my-oceanbase-instance",
|
||||
Type: oceanbase.SourceType,
|
||||
Kind: oceanbase.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "2881",
|
||||
Database: "ob_db",
|
||||
@@ -85,12 +84,16 @@ func TestParseFromYamlOceanBase(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -106,35 +109,39 @@ func TestFailParseFromYamlOceanBase(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-oceanbase-instance
|
||||
type: oceanbase
|
||||
host: 0.0.0.0
|
||||
port: 2881
|
||||
database: ob_db
|
||||
user: ob_user
|
||||
password: ob_pass
|
||||
foo: bar
|
||||
sources:
|
||||
my-oceanbase-instance:
|
||||
kind: oceanbase
|
||||
host: 0.0.0.0
|
||||
port: 2881
|
||||
database: ob_db
|
||||
user: ob_user
|
||||
password: ob_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-oceanbase-instance\" as \"oceanbase\": [2:1] unknown field \"foo\"\n 1 | database: ob_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-oceanbase-instance\n 5 | password: ob_pass\n 6 | ",
|
||||
err: "unable to parse source \"my-oceanbase-instance\" as \"oceanbase\": [2:1] unknown field \"foo\"\n 1 | database: ob_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: oceanbase\n 5 | password: ob_pass\n 6 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-oceanbase-instance
|
||||
type: oceanbase
|
||||
port: 2881
|
||||
database: ob_db
|
||||
user: ob_user
|
||||
password: ob_pass
|
||||
sources:
|
||||
my-oceanbase-instance:
|
||||
kind: oceanbase
|
||||
port: 2881
|
||||
database: ob_db
|
||||
user: ob_user
|
||||
password: ob_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-oceanbase-instance\" as \"oceanbase\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-oceanbase-instance\" as \"oceanbase\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -18,14 +18,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "oracle"
|
||||
const SourceKind string = "oracle"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
ConnectionString string `yaml:"connectionString,omitempty"`
|
||||
TnsAlias string `yaml:"tnsAlias,omitempty"`
|
||||
TnsAdmin string `yaml:"tnsAdmin,omitempty"`
|
||||
@@ -95,8 +95,8 @@ func (c Config) validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -124,8 +124,8 @@ type Source struct {
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -239,7 +239,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Config) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, config.Name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, config.Name)
|
||||
defer span.End()
|
||||
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
|
||||
@@ -3,13 +3,12 @@
|
||||
package oracle_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -23,18 +22,18 @@ func TestParseFromYamlOracle(t *testing.T) {
|
||||
{
|
||||
desc: "connection string and useOCI=true",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-oracle-cs
|
||||
type: oracle
|
||||
connectionString: "my-host:1521/XEPDB1"
|
||||
user: my_user
|
||||
password: my_pass
|
||||
useOCI: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
sources:
|
||||
my-oracle-cs:
|
||||
kind: oracle
|
||||
connectionString: "my-host:1521/XEPDB1"
|
||||
user: my_user
|
||||
password: my_pass
|
||||
useOCI: true
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-oracle-cs": oracle.Config{
|
||||
Name: "my-oracle-cs",
|
||||
Type: oracle.SourceType,
|
||||
Kind: oracle.SourceKind,
|
||||
ConnectionString: "my-host:1521/XEPDB1",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
@@ -45,19 +44,19 @@ func TestParseFromYamlOracle(t *testing.T) {
|
||||
{
|
||||
desc: "host/port/serviceName and default useOCI=false",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-oracle-host
|
||||
type: oracle
|
||||
host: my-host
|
||||
port: 1521
|
||||
serviceName: ORCLPDB
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
sources:
|
||||
my-oracle-host:
|
||||
kind: oracle
|
||||
host: my-host
|
||||
port: 1521
|
||||
serviceName: ORCLPDB
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-oracle-host": oracle.Config{
|
||||
Name: "my-oracle-host",
|
||||
Type: oracle.SourceType,
|
||||
Kind: oracle.SourceKind,
|
||||
Host: "my-host",
|
||||
Port: 1521,
|
||||
ServiceName: "ORCLPDB",
|
||||
@@ -70,19 +69,19 @@ func TestParseFromYamlOracle(t *testing.T) {
|
||||
{
|
||||
desc: "tnsAlias and TnsAdmin specified with explicit useOCI=true",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-oracle-tns-oci
|
||||
type: oracle
|
||||
tnsAlias: FINANCE_DB
|
||||
tnsAdmin: /opt/oracle/network/admin
|
||||
user: my_user
|
||||
password: my_pass
|
||||
useOCI: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
sources:
|
||||
my-oracle-tns-oci:
|
||||
kind: oracle
|
||||
tnsAlias: FINANCE_DB
|
||||
tnsAdmin: /opt/oracle/network/admin
|
||||
user: my_user
|
||||
password: my_pass
|
||||
useOCI: true
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-oracle-tns-oci": oracle.Config{
|
||||
Name: "my-oracle-tns-oci",
|
||||
Type: oracle.SourceType,
|
||||
Kind: oracle.SourceKind,
|
||||
TnsAlias: "FINANCE_DB",
|
||||
TnsAdmin: "/opt/oracle/network/admin",
|
||||
User: "my_user",
|
||||
@@ -94,18 +93,22 @@ func TestParseFromYamlOracle(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse:\nwant: %v\ngot: %v\ndiff: %s", tc.want, got, cmp.Diff(tc.want, got))
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse:\nwant: %v\ngot: %v\ndiff: %s", tc.want, got.Sources, cmp.Diff(tc.want, got.Sources))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
func TestFailParseFromYamlOracle(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
@@ -114,72 +117,76 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-oracle-instance
|
||||
type: oracle
|
||||
host: my-host
|
||||
serviceName: ORCL
|
||||
user: my_user
|
||||
password: my_pass
|
||||
extraField: value
|
||||
sources:
|
||||
my-oracle-instance:
|
||||
kind: oracle
|
||||
host: my-host
|
||||
serviceName: ORCL
|
||||
user: my_user
|
||||
password: my_pass
|
||||
extraField: value
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-oracle-instance\" as \"oracle\": [1:1] unknown field \"extraField\"\n> 1 | extraField: value\n ^\n 2 | host: my-host\n 3 | name: my-oracle-instance\n 4 | password: my_pass\n 5 | ",
|
||||
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": [1:1] unknown field \"extraField\"\n> 1 | extraField: value\n ^\n 2 | host: my-host\n 3 | kind: oracle\n 4 | password: my_pass\n 5 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required password field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-oracle-instance
|
||||
type: oracle
|
||||
host: my-host
|
||||
serviceName: ORCL
|
||||
user: my_user
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-oracle-instance\" as \"oracle\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
|
||||
sources:
|
||||
my-oracle-instance:
|
||||
kind: oracle
|
||||
host: my-host
|
||||
serviceName: ORCL
|
||||
user: my_user
|
||||
`,
|
||||
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
|
||||
},
|
||||
{
|
||||
desc: "missing connection method fields (validate fails)",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-oracle-instance
|
||||
type: oracle
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'",
|
||||
sources:
|
||||
my-oracle-instance:
|
||||
kind: oracle
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'",
|
||||
},
|
||||
{
|
||||
desc: "multiple connection methods provided (validate fails)",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-oracle-instance
|
||||
type: oracle
|
||||
host: my-host
|
||||
serviceName: ORCL
|
||||
connectionString: "my-host:1521/XEPDB1"
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'",
|
||||
sources:
|
||||
my-oracle-instance:
|
||||
kind: oracle
|
||||
host: my-host
|
||||
serviceName: ORCL
|
||||
connectionString: "my-host:1521/XEPDB1"
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'",
|
||||
},
|
||||
{
|
||||
desc: "fail on tnsAdmin with useOCI=false",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-oracle-fail
|
||||
type: oracle
|
||||
tnsAlias: FINANCE_DB
|
||||
tnsAdmin: /opt/oracle/network/admin
|
||||
user: my_user
|
||||
password: my_pass
|
||||
useOCI: false
|
||||
sources:
|
||||
my-oracle-fail:
|
||||
kind: oracle
|
||||
tnsAlias: FINANCE_DB
|
||||
tnsAdmin: /opt/oracle/network/admin
|
||||
user: my_user
|
||||
password: my_pass
|
||||
useOCI: false
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-oracle-fail\" as \"oracle\": invalid Oracle configuration: `tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead",
|
||||
err: "unable to parse source \"my-oracle-fail\" as \"oracle\": invalid Oracle configuration: `tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -28,14 +28,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "postgres"
|
||||
const SourceKind string = "postgres"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -58,8 +58,8 @@ type Config struct {
|
||||
QueryParams map[string]string `yaml:"queryParams"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -87,8 +87,8 @@ type Source struct {
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -128,7 +128,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2025 Google LLC
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -15,14 +15,13 @@
|
||||
package postgres_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -36,19 +35,19 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: postgres
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: postgres
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": postgres.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: postgres.SourceType,
|
||||
Kind: postgres.SourceKind,
|
||||
Host: "my-host",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -60,22 +59,22 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
{
|
||||
desc: "example with query params",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: postgres
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryParams:
|
||||
sslmode: verify-full
|
||||
sslrootcert: /tmp/ca.crt
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: postgres
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryParams:
|
||||
sslmode: verify-full
|
||||
sslrootcert: /tmp/ca.crt
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": postgres.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: postgres.SourceType,
|
||||
Kind: postgres.SourceKind,
|
||||
Host: "my-host",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -91,12 +90,16 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -112,35 +115,39 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: postgres
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: postgres
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"postgres\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: my-host\n 4 | name: my-pg-instance\n 5 | password: my_pass\n 6 | ",
|
||||
err: "unable to parse source \"my-pg-instance\" as \"postgres\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: my-host\n 4 | kind: postgres\n 5 | password: my_pass\n 6 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: postgres
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: postgres
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"postgres\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-pg-instance\" as \"postgres\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -24,14 +24,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "redis"
|
||||
const SourceKind string = "redis"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Address []string `yaml:"address" validate:"required"`
|
||||
Username string `yaml:"username"`
|
||||
Password string `yaml:"password"`
|
||||
@@ -54,8 +54,8 @@ type Config struct {
|
||||
ClusterEnabled bool `yaml:"clusterEnabled"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
// RedisClient is an interface for `redis.Client` and `redis.ClusterClient
|
||||
@@ -141,8 +141,8 @@ type Source struct {
|
||||
Client RedisClient
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -15,13 +15,12 @@
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/redis"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -35,16 +34,16 @@ func TestParseFromYamlRedis(t *testing.T) {
|
||||
{
|
||||
desc: "default setting",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-redis-instance
|
||||
type: redis
|
||||
address:
|
||||
- 127.0.0.1
|
||||
sources:
|
||||
my-redis-instance:
|
||||
kind: redis
|
||||
address:
|
||||
- 127.0.0.1
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-redis-instance": redis.Config{
|
||||
Name: "my-redis-instance",
|
||||
Type: redis.SourceType,
|
||||
Kind: redis.SourceKind,
|
||||
Address: []string{"127.0.0.1"},
|
||||
ClusterEnabled: false,
|
||||
UseGCPIAM: false,
|
||||
@@ -54,20 +53,20 @@ func TestParseFromYamlRedis(t *testing.T) {
|
||||
{
|
||||
desc: "advanced example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-redis-instance
|
||||
type: redis
|
||||
address:
|
||||
- 127.0.0.1
|
||||
password: my-pass
|
||||
database: 1
|
||||
useGCPIAM: true
|
||||
clusterEnabled: true
|
||||
sources:
|
||||
my-redis-instance:
|
||||
kind: redis
|
||||
address:
|
||||
- 127.0.0.1
|
||||
password: my-pass
|
||||
database: 1
|
||||
useGCPIAM: true
|
||||
clusterEnabled: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-redis-instance": redis.Config{
|
||||
Name: "my-redis-instance",
|
||||
Type: redis.SourceType,
|
||||
Kind: redis.SourceKind,
|
||||
Address: []string{"127.0.0.1"},
|
||||
Password: "my-pass",
|
||||
Database: 1,
|
||||
@@ -79,12 +78,16 @@ func TestParseFromYamlRedis(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -100,43 +103,48 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "invalid database",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-redis-instance
|
||||
type: redis
|
||||
address:
|
||||
- 127.0.0.1
|
||||
password: my-pass
|
||||
database: data
|
||||
sources:
|
||||
my-redis-instance:
|
||||
kind: redis
|
||||
project: my-project
|
||||
address:
|
||||
- 127.0.0.1
|
||||
password: my-pass
|
||||
database: data
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-redis-instance\" as \"redis\": [3:11] cannot unmarshal string into Go struct field Config.Database of type int\n 1 | address:\n 2 | - 127.0.0.1\n> 3 | database: data\n ^\n 4 | name: my-redis-instance\n 5 | password: my-pass\n 6 | type: redis",
|
||||
err: "cannot unmarshal string into Go struct field .Sources of type int",
|
||||
},
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-redis-instance
|
||||
type: redis
|
||||
project: my-project
|
||||
address:
|
||||
- 127.0.0.1
|
||||
password: my-pass
|
||||
database: 1
|
||||
sources:
|
||||
my-redis-instance:
|
||||
kind: redis
|
||||
project: my-project
|
||||
address:
|
||||
- 127.0.0.1
|
||||
password: my-pass
|
||||
database: 1
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-redis-instance\" as \"redis\": [6:1] unknown field \"project\"\n 3 | database: 1\n 4 | name: my-redis-instance\n 5 | password: my-pass\n> 6 | project: my-project\n ^\n 7 | type: redis",
|
||||
err: "unable to parse source \"my-redis-instance\" as \"redis\": [6:1] unknown field \"project\"",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-redis-instance
|
||||
type: redis
|
||||
sources:
|
||||
my-redis-instance:
|
||||
kind: redis
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-redis-instance\" as \"redis\": Key: 'Config.Address' Error:Field validation for 'Address' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-redis-instance\" as \"redis\": Key: 'Config.Address' Error:Field validation for 'Address' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -33,14 +33,14 @@ import (
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const SourceType string = "serverless-spark"
|
||||
const SourceKind string = "serverless-spark"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,13 +54,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Location string `yaml:"location" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -94,8 +94,8 @@ type Source struct {
|
||||
OpsClient *longrunning.OperationsClient
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package serverlessspark_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,16 +33,16 @@ func TestParseFromYamlServerlessSpark(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: serverless-spark
|
||||
project: my-project
|
||||
location: my-location
|
||||
sources:
|
||||
my-instance:
|
||||
kind: serverless-spark
|
||||
project: my-project
|
||||
location: my-location
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": serverlessspark.Config{
|
||||
Name: "my-instance",
|
||||
Type: serverlessspark.SourceType,
|
||||
Kind: serverlessspark.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "my-location",
|
||||
},
|
||||
@@ -52,12 +51,16 @@ func TestParseFromYamlServerlessSpark(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -73,39 +76,43 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: serverless-spark
|
||||
project: my-project
|
||||
location: my-location
|
||||
foo: bar
|
||||
sources:
|
||||
my-instance:
|
||||
kind: serverless-spark
|
||||
project: my-project
|
||||
location: my-location
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"serverless-spark\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | location: my-location\n 3 | name: my-instance\n 4 | project: my-project\n 5 | ",
|
||||
err: "unable to parse source \"my-instance\" as \"serverless-spark\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: serverless-spark\n 3 | location: my-location\n 4 | project: my-project",
|
||||
},
|
||||
{
|
||||
desc: "missing required field project",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: serverless-spark
|
||||
location: my-location
|
||||
sources:
|
||||
my-instance:
|
||||
kind: serverless-spark
|
||||
location: my-location
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"serverless-spark\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-instance\" as \"serverless-spark\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
},
|
||||
{
|
||||
desc: "missing required field location",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-instance
|
||||
type: serverless-spark
|
||||
project: my-project
|
||||
sources:
|
||||
my-instance:
|
||||
kind: serverless-spark
|
||||
project: my-project
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"serverless-spark\": Key: 'Config.Location' Error:Field validation for 'Location' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-instance\" as \"serverless-spark\": Key: 'Config.Location' Error:Field validation for 'Location' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -29,15 +29,15 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// SourceType for SingleStore source
|
||||
const SourceType string = "singlestore"
|
||||
// SourceKind for SingleStore source
|
||||
const SourceKind string = "singlestore"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
// Config holds the configuration parameters for connecting to a SingleStore database.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
@@ -61,9 +61,9 @@ type Config struct {
|
||||
QueryTimeout string `yaml:"queryTimeout"`
|
||||
}
|
||||
|
||||
// SourceConfigType returns the type of the source configuration.
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
// SourceConfigKind returns the kind of the source configuration.
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
// Initialize sets up the SingleStore connection pool and returns a Source.
|
||||
@@ -93,9 +93,9 @@ type Source struct {
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
// SourceType returns the type of the source configuration.
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
// SourceKind returns the kind of the source configuration.
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -162,7 +162,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initSingleStoreConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
package singlestore_test
|
||||
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -12,15 +14,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package singlestore_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/singlestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,19 +33,19 @@ func TestParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-s2-instance
|
||||
type: singlestore
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-s2-instance:
|
||||
kind: singlestore
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-s2-instance": singlestore.Config{
|
||||
Name: "my-s2-instance",
|
||||
Type: singlestore.SourceType,
|
||||
Kind: singlestore.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -58,20 +57,20 @@ func TestParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "with query timeout",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-s2-instance
|
||||
type: singlestore
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryTimeout: 45s
|
||||
sources:
|
||||
my-s2-instance:
|
||||
kind: singlestore
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryTimeout: 45s
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-s2-instance": singlestore.Config{
|
||||
Name: "my-s2-instance",
|
||||
Type: singlestore.SourceType,
|
||||
Kind: singlestore.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
@@ -84,12 +83,16 @@ func TestParseFromYaml(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -105,35 +108,39 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-s2-instance
|
||||
type: singlestore
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
sources:
|
||||
my-s2-instance:
|
||||
kind: singlestore
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-s2-instance\" as \"singlestore\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-s2-instance\n 5 | password: my_pass\n 6 | ",
|
||||
err: "unable to parse source \"my-s2-instance\" as \"singlestore\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: singlestore\n 5 | password: my_pass\n 6 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-s2-instance
|
||||
type: singlestore
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
sources:
|
||||
my-s2-instance:
|
||||
kind: singlestore
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-s2-instance\" as \"singlestore\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-s2-instance\" as \"singlestore\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -25,14 +25,14 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "snowflake"
|
||||
const SourceKind string = "snowflake"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Account string `yaml:"account" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
@@ -56,8 +56,8 @@ type Config struct {
|
||||
Role string `yaml:"role"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -85,8 +85,8 @@ type Source struct {
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -137,7 +137,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
|
||||
|
||||
func initSnowflakeConnection(ctx context.Context, tracer trace.Tracer, name, account, user, password, database, schema, warehouse, role string) (*sqlx.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Set defaults for optional parameters
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
package snowflake_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/snowflake"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
@@ -34,19 +33,19 @@ func TestParseFromYamlSnowflake(t *testing.T) {
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-snowflake-instance
|
||||
type: snowflake
|
||||
account: my-account
|
||||
user: my_user
|
||||
password: my_pass
|
||||
database: my_db
|
||||
schema: my_schema
|
||||
sources:
|
||||
my-snowflake-instance:
|
||||
kind: snowflake
|
||||
account: my-account
|
||||
user: my_user
|
||||
password: my_pass
|
||||
database: my_db
|
||||
schema: my_schema
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-snowflake-instance": snowflake.Config{
|
||||
Name: "my-snowflake-instance",
|
||||
Type: snowflake.SourceType,
|
||||
Kind: snowflake.SourceKind,
|
||||
Account: "my-account",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
@@ -60,12 +59,16 @@ func TestParseFromYamlSnowflake(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -81,35 +84,39 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-snowflake-instance
|
||||
type: snowflake
|
||||
account: my-account
|
||||
user: my_user
|
||||
password: my_pass
|
||||
database: my_db
|
||||
schema: my_schema
|
||||
foo: bar
|
||||
sources:
|
||||
my-snowflake-instance:
|
||||
kind: snowflake
|
||||
account: my-account
|
||||
user: my_user
|
||||
password: my_pass
|
||||
database: my_db
|
||||
schema: my_schema
|
||||
foo: bar
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-snowflake-instance\" as \"snowflake\": [3:1] unknown field \"foo\"\n 1 | account: my-account\n 2 | database: my_db\n> 3 | foo: bar\n ^\n 4 | name: my-snowflake-instance\n 5 | password: my_pass\n 6 | schema: my_schema\n 7 | ",
|
||||
err: "unable to parse source \"my-snowflake-instance\" as \"snowflake\": [3:1] unknown field \"foo\"\n 1 | account: my-account\n 2 | database: my_db\n> 3 | foo: bar\n ^\n 4 | kind: snowflake\n 5 | password: my_pass\n 6 | schema: my_schema\n 7 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-snowflake-instance
|
||||
type: snowflake
|
||||
account: my-account
|
||||
user: my_user
|
||||
password: my_pass
|
||||
database: my_db
|
||||
sources:
|
||||
my-snowflake-instance:
|
||||
kind: snowflake
|
||||
account: my-account
|
||||
user: my_user
|
||||
password: my_pass
|
||||
database: my_db
|
||||
`,
|
||||
err: "error unmarshaling sources: unable to parse source \"my-snowflake-instance\" as \"snowflake\": Key: 'Config.Schema' Error:Field validation for 'Schema' failed on the 'required' tag",
|
||||
err: "unable to parse source \"my-snowflake-instance\" as \"snowflake\": Key: 'Config.Schema' Error:Field validation for 'Schema' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
|
||||
@@ -29,48 +29,48 @@ type SourceConfigFactory func(ctx context.Context, name string, decoder *yaml.De
|
||||
|
||||
var sourceRegistry = make(map[string]SourceConfigFactory)
|
||||
|
||||
// Register registers a new source type with its factory.
|
||||
// It returns false if the type is already registered.
|
||||
func Register(sourceType string, factory SourceConfigFactory) bool {
|
||||
if _, exists := sourceRegistry[sourceType]; exists {
|
||||
// Source with this type already exists, do not overwrite.
|
||||
// Register registers a new source kind with its factory.
|
||||
// It returns false if the kind is already registered.
|
||||
func Register(kind string, factory SourceConfigFactory) bool {
|
||||
if _, exists := sourceRegistry[kind]; exists {
|
||||
// Source with this kind already exists, do not overwrite.
|
||||
return false
|
||||
}
|
||||
sourceRegistry[sourceType] = factory
|
||||
sourceRegistry[kind] = factory
|
||||
return true
|
||||
}
|
||||
|
||||
// DecodeConfig decodes a source configuration using the registered factory for the given type.
|
||||
func DecodeConfig(ctx context.Context, sourceType string, name string, decoder *yaml.Decoder) (SourceConfig, error) {
|
||||
factory, found := sourceRegistry[sourceType]
|
||||
// DecodeConfig decodes a source configuration using the registered factory for the given kind.
|
||||
func DecodeConfig(ctx context.Context, kind string, name string, decoder *yaml.Decoder) (SourceConfig, error) {
|
||||
factory, found := sourceRegistry[kind]
|
||||
if !found {
|
||||
return nil, fmt.Errorf("unknown source type: %q", sourceType)
|
||||
return nil, fmt.Errorf("unknown source kind: %q", kind)
|
||||
}
|
||||
sourceConfig, err := factory(ctx, name, decoder)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse source %q as %q: %w", name, sourceType, err)
|
||||
return nil, fmt.Errorf("unable to parse source %q as %q: %w", name, kind, err)
|
||||
}
|
||||
return sourceConfig, err
|
||||
}
|
||||
|
||||
// SourceConfig is the interface for configuring a source.
|
||||
type SourceConfig interface {
|
||||
SourceConfigType() string
|
||||
SourceConfigKind() string
|
||||
Initialize(ctx context.Context, tracer trace.Tracer) (Source, error)
|
||||
}
|
||||
|
||||
// Source is the interface for the source itself.
|
||||
type Source interface {
|
||||
SourceType() string
|
||||
SourceKind() string
|
||||
ToConfig() SourceConfig
|
||||
}
|
||||
|
||||
// InitConnectionSpan adds a span for database pool connection initialization
|
||||
func InitConnectionSpan(ctx context.Context, tracer trace.Tracer, sourceType, sourceName string) (context.Context, trace.Span) {
|
||||
func InitConnectionSpan(ctx context.Context, tracer trace.Tracer, sourceKind, sourceName string) (context.Context, trace.Span) {
|
||||
ctx, span := tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/source/connect",
|
||||
trace.WithAttributes(attribute.String("source_type", sourceType)),
|
||||
trace.WithAttributes(attribute.String("source_kind", sourceKind)),
|
||||
trace.WithAttributes(attribute.String("source_name", sourceName)),
|
||||
)
|
||||
return ctx, span
|
||||
|
||||
@@ -28,14 +28,14 @@ import (
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
const SourceType string = "spanner"
|
||||
const SourceKind string = "spanner"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,15 +49,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Instance string `yaml:"instance" validate:"required"`
|
||||
Dialect sources.Dialect `yaml:"dialect" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
@@ -80,8 +80,8 @@ type Source struct {
|
||||
Client *spanner.Client
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
@@ -171,7 +171,7 @@ func (s *Source) RunSQL(ctx context.Context, readOnly bool, statement string, pa
|
||||
|
||||
func initSpannerClient(ctx context.Context, tracer trace.Tracer, name, project, instance, dbname string) (*spanner.Client, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the connection to the database
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user