Compare commits

..

7 Commits

532 changed files with 4591 additions and 5350 deletions

View File

@@ -59,13 +59,6 @@ You can manually trigger the bot by commenting on your Pull Request:
* `/gemini summary`: Posts a summary of the changes in the pull request.
* `/gemini help`: Overview of the available commands
## Guidelines for Pull Requests
1. Please keep your PR small for more thorough review and easier updates. In case of regression, it also allows us to roll back a single feature instead of multiple ones.
1. For non-trivial changes, consider opening an issue and discussing it with the code owners first.
1. Provide a good PR description as a record of what change is being made and why it was made. Link to a GitHub issue if it exists.
1. Make sure your code is thoroughly tested with unit tests and integration tests. Remember to clean up the test instances properly in your code to avoid memory leaks.
## Adding a New Database Source or Tool
Please create an
@@ -92,11 +85,11 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
`newdb.go`. Create a `Config` struct to include all the necessary parameters
for connecting to the database (e.g., host, port, username, password, database
name) and a `Source` struct to store necessary parameters for tools (e.g.,
Name, Type, connection object, additional config).
Name, Kind, connection object, additional config).
* **Implement the
[`SourceConfig`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/sources/sources.go#L57)
interface**. This interface requires two methods:
* `SourceConfigType() string`: Returns a unique string identifier for your
* `SourceConfigKind() string`: Returns a unique string identifier for your
data source (e.g., `"newdb"`).
* `Initialize(ctx context.Context, tracer trace.Tracer) (Source, error)`:
Creates a new instance of your data source and establishes a connection to
@@ -104,7 +97,7 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
* **Implement the
[`Source`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/sources/sources.go#L63)
interface**. This interface requires one method:
* `SourceType() string`: Returns the same string identifier as `SourceConfigType()`.
* `SourceKind() string`: Returns the same string identifier as `SourceConfigKind()`.
* **Implement `init()`** to register the new Source.
* **Implement Unit Tests** in a file named `newdb_test.go`.
@@ -117,8 +110,6 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
We recommend looking at an [example tool
implementation](https://github.com/googleapis/genai-toolbox/tree/main/internal/tools/postgres/postgressql).
Remember to keep your PRs small. For example, if you are contributing a new Source, only include one or two core Tools within the same PR, the rest of the Tools can come in subsequent PRs.
* **Create a new directory** under `internal/tools` for your tool type (e.g., `internal/tools/newdb/newdbtool`).
* **Define a configuration struct** for your tool in a file named `newdbtool.go`.
Create a `Config` struct and a `Tool` struct to store necessary parameters for
@@ -126,7 +117,7 @@ tools.
* **Implement the
[`ToolConfig`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/tools/tools.go#L61)
interface**. This interface requires one method:
* `ToolConfigType() string`: Returns a unique string identifier for your tool
* `ToolConfigKind() string`: Returns a unique string identifier for your tool
(e.g., `"newdb-tool"`).
* `Initialize(sources map[string]Source) (Tool, error)`: Creates a new
instance of your tool and validates that it can connect to the specified
@@ -172,8 +163,6 @@ tools.
parameters][temp-param-doc]. Only run this test if template
parameters apply to your tool.
* **Add additional tests** for the tools that are not covered by the predefined tests. Every tool must be tested!
* **Add the new database to the integration test workflow** in
[integration.cloudbuild.yaml](.ci/integration.cloudbuild.yaml).
@@ -243,7 +232,7 @@ resources.
| style | Update src code, with only formatting and whitespace updates (e.g. code formatter or linter changes). |
Pull requests should always add scope whenever possible. The scope is
formatted as `<scope-resource>/<scope-type>` (e.g., `sources/postgres`, or
formatted as `<scope-type>/<scope-kind>` (e.g., `sources/postgres`, or
`tools/mssql-sql`).
Ideally, **each PR covers only one scope**, if this is
@@ -255,4 +244,4 @@ resources.
* **PR Description:** PR description should **always** be included. It should
include a concise description of the changes, it's impact, along with a
summary of the solution. If the PR is related to a specific issue, the issue
number should be mentioned in the PR description (e.g. `Fixes #1`).
number should be mentioned in the PR description (e.g. `Fixes #1`).

View File

@@ -954,7 +954,7 @@ For more details on configuring different types of sources, see the
### Tools
The `tools` section of a `tools.yaml` define the actions an agent can take: what
type of tool it is, which source(s) it affects, what parameters it uses, etc.
kind of tool it is, which source(s) it affects, what parameters it uses, etc.
```yaml
tools:

View File

@@ -15,7 +15,6 @@
package cmd
import (
"bytes"
"context"
_ "embed"
"fmt"
@@ -93,7 +92,6 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances"
@@ -395,6 +393,7 @@ func NewCommand(opts ...Option) *Command {
type ToolsFile struct {
Sources server.SourceConfigs `yaml:"sources"`
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"`
Tools server.ToolConfigs `yaml:"tools"`
@@ -425,106 +424,6 @@ func parseEnv(input string) (string, error) {
return output, err
}
func convertToolsFile(ctx context.Context, raw []byte) ([]byte, error) {
var input yaml.MapSlice
decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap())
if err := decoder.Decode(&input); err != nil {
return nil, err
}
// Convert raw MapSlice to a helper map for quick lookup
// while keeping the values as MapSlices to preserve internal order
resourceOrder := []string{}
lookup := make(map[string]yaml.MapSlice)
for _, item := range input {
key, ok := item.Key.(string)
if !ok {
return nil, fmt.Errorf("unexpected non-string key in input: %v", item.Key)
}
if slice, ok := item.Value.(yaml.MapSlice); ok {
// convert authSources to authServices
if key == "authSources" {
key = "authServices"
}
// works even if lookup[key] is nil
lookup[key] = append(lookup[key], slice...)
// preserving the resource's order of original toolsFile
if !slices.Contains(resourceOrder, key) {
resourceOrder = append(resourceOrder, key)
}
} else {
// toolsfile is already v2
if key == "kind" {
return raw, nil
}
return nil, fmt.Errorf("'%s' is not a map", key)
}
}
// convert to tools file v2
var buf bytes.Buffer
encoder := yaml.NewEncoder(&buf)
for _, kind := range resourceOrder {
data, exists := lookup[kind]
if !exists {
// if this is skipped for all keys, the tools file is in v2
continue
}
// Transform each entry
for _, entry := range data {
entryName, ok := entry.Key.(string)
if !ok {
return nil, fmt.Errorf("unexpected non-string key for entry in '%s': %v", kind, entry.Key)
}
entryBody := ProcessValue(entry.Value, kind == "toolsets")
transformed := yaml.MapSlice{
{Key: "kind", Value: kind},
{Key: "name", Value: entryName},
}
// Merge the transformed body into our result
if bodySlice, ok := entryBody.(yaml.MapSlice); ok {
transformed = append(transformed, bodySlice...)
} else {
return nil, fmt.Errorf("unable to convert entryBody to MapSlice")
}
if err := encoder.Encode(transformed); err != nil {
return nil, err
}
}
}
return buf.Bytes(), nil
}
// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type'
func ProcessValue(v any, isToolset bool) any {
switch val := v.(type) {
case yaml.MapSlice:
for i := range val {
// Perform renaming
if val[i].Key == "kind" {
val[i].Key = "type"
}
// Recursive call for nested values (e.g., nested objects or lists)
val[i].Value = ProcessValue(val[i].Value, false)
}
return val
case []any:
// Process lists: If it's a toolset top-level list, wrap it.
if isToolset {
return yaml.MapSlice{{Key: "tools", Value: val}}
}
// Otherwise, recurse into list items (to catch nested objects)
for i := range val {
val[i] = ProcessValue(val[i], false)
}
return val
default:
return val
}
}
// parseToolsFile parses the provided yaml into appropriate configs.
func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
var toolsFile ToolsFile
@@ -535,13 +434,8 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
}
raw = []byte(output)
raw, err = convertToolsFile(ctx, raw)
if err != nil {
return toolsFile, fmt.Errorf("error converting tools file: %s", err)
}
// Parse contents
toolsFile.Sources, toolsFile.AuthServices, toolsFile.EmbeddingModels, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts, err = server.UnmarshalResourceConfig(ctx, raw)
err = yaml.UnmarshalContext(ctx, raw, &toolsFile, yaml.Strict())
if err != nil {
return toolsFile, err
}
@@ -573,6 +467,18 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
}
}
// Check for conflicts and merge authSources (deprecated, but still support)
for name, authSource := range file.AuthSources {
if _, exists := merged.AuthSources[name]; exists {
conflicts = append(conflicts, fmt.Sprintf("authSource '%s' (file #%d)", name, fileIndex+1))
} else {
if merged.AuthSources == nil {
merged.AuthSources = make(server.AuthServiceConfigs)
}
merged.AuthSources[name] = authSource
}
}
// Check for conflicts and merge authServices
for name, authService := range file.AuthServices {
if _, exists := merged.AuthServices[name]; exists {
@@ -1048,6 +954,20 @@ func run(cmd *Command) error {
cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets
cmd.cfg.PromptConfigs = finalToolsFile.Prompts
authSourceConfigs := finalToolsFile.AuthSources
if authSourceConfigs != nil {
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
for k, v := range authSourceConfigs {
if _, exists := cmd.cfg.AuthServiceConfigs[k]; exists {
errMsg := fmt.Errorf("resource conflict detected: authSource '%s' has the same name as an existing authService. Please rename your authSource", k)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
cmd.cfg.AuthServiceConfigs[k] = v
}
}
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
if err != nil {
errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err)

View File

@@ -23,14 +23,12 @@ import (
"os"
"path"
"path/filepath"
"reflect"
"regexp"
"runtime"
"strings"
"testing"
"time"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/auth/google"
@@ -496,309 +494,6 @@ func TestDefaultLogLevel(t *testing.T) {
}
}
func TestConvertToolsFile(t *testing.T) {
ctx, cancelCtx := context.WithTimeout(context.Background(), time.Minute)
defer cancelCtx()
pr, pw := io.Pipe()
defer pw.Close()
defer pr.Close()
logger, err := log.NewStdLogger(pw, pw, "DEBUG")
if err != nil {
t.Fatalf("failed to setup logger %s", err)
}
ctx = util.WithLogger(ctx, logger)
tcs := []struct {
desc string
in string
want string
isErr bool
errStr string
}{
{
desc: "basic convert",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
authServices:
my-google-auth:
kind: google
clientId: testing-id
tools:
example_tool:
kind: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
toolsets:
example_toolset:
- example_tool
prompts:
code_review:
description: ask llm to analyze code quality
messages:
- content: "please review the following code for quality: {{.code}}"
arguments:
- name: code
description: the code to review
embeddingModels:
gemini-model:
kind: gemini
model: gemini-embedding-001
apiKey: some-key
dimension: 768`,
want: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: authServices
name: my-google-auth
type: google
clientId: testing-id
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool
---
kind: prompts
name: code_review
description: ask llm to analyze code quality
messages:
- content: "please review the following code for quality: {{.code}}"
arguments:
- name: code
description: the code to review
---
kind: embeddingModels
name: gemini-model
type: gemini
model: gemini-embedding-001
apiKey: some-key
dimension: 768`,
},
{
desc: "preserve resource order with grouping",
in: `
tools:
example_tool:
kind: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
authServices:
my-google-auth:
kind: google
clientId: testing-id
toolsets:
example_toolset:
- example_tool
authSources:
my-google-auth:
kind: google
clientId: testing-id`,
want: `
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: authServices
name: my-google-auth
type: google
clientId: testing-id
---
kind: authServices
name: my-google-auth
type: google
clientId: testing-id
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
},
{
desc: "no convertion needed",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
want: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
},
{
desc: "invalid source",
in: `sources: invalid`,
isErr: true,
errStr: "'sources' is not a map",
},
{
desc: "invalid toolset",
in: `toolsets: invalid`,
isErr: true,
errStr: "'toolsets' is not a map",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
output, err := convertToolsFile(ctx, []byte(tc.in))
if tc.isErr {
if err == nil {
t.Fatalf("missing error: %s", tc.errStr)
}
if err.Error() != tc.errStr {
t.Fatalf("invalid error string: got %s, want %s", err, tc.errStr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
var docs1, docs2 []yaml.MapSlice
if docs1, err = decodeToMapSlice(string(output)); err != nil {
t.Fatalf("error decoding output: %s", err)
}
if docs2, err = decodeToMapSlice(tc.want); err != nil {
t.Fatalf("Error decoding want: %s", err)
}
if !reflect.DeepEqual(docs1, docs2) {
t.Fatalf("incorrect output: got %s, want %s", string(output), tc.want)
}
})
}
}
func decodeToMapSlice(data string) ([]yaml.MapSlice, error) {
// ensures that the order is correct
var docs []yaml.MapSlice
decoder := yaml.NewDecoder(strings.NewReader(data))
for {
var doc yaml.MapSlice
err := decoder.Decode(&doc)
if err == io.EOF {
break
}
if err != nil {
return nil, err
}
docs = append(docs, doc)
}
return docs, nil
}
func TestParseToolFile(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
@@ -810,7 +505,7 @@ func TestParseToolFile(t *testing.T) {
wantToolsFile ToolsFile
}{
{
description: "basic example tools file v1",
description: "basic example",
in: `
sources:
my-pg-instance:
@@ -840,7 +535,7 @@ func TestParseToolFile(t *testing.T) {
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Type: cloudsqlpgsrc.SourceType,
Kind: cloudsqlpgsrc.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -853,7 +548,7 @@ func TestParseToolFile(t *testing.T) {
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Type: "postgres-sql",
Kind: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
@@ -873,121 +568,7 @@ func TestParseToolFile(t *testing.T) {
},
},
{
description: "basic example tools file v2",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: authServices
name: my-google-auth
type: google
clientId: testing-id
---
kind: embeddingModels
name: gemini-model
type: gemini
model: gemini-embedding-001
apiKey: some-key
dimension: 768
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool
---
kind: prompts
name: code_review
description: ask llm to analyze code quality
messages:
- content: "please review the following code for quality: {{.code}}"
arguments:
- name: code
description: the code to review
`,
wantToolsFile: ToolsFile{
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Type: cloudsqlpgsrc.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
IPType: "public",
Database: "my_db",
User: "my_user",
Password: "my_pass",
},
},
AuthServices: server.AuthServiceConfigs{
"my-google-auth": google.Config{
Name: "my-google-auth",
Type: google.AuthServiceType,
ClientID: "testing-id",
},
},
EmbeddingModels: server.EmbeddingModelConfigs{
"gemini-model": gemini.Config{
Name: "gemini-model",
Type: gemini.EmbeddingModelType,
Model: "gemini-embedding-001",
ApiKey: "some-key",
Dimension: 768,
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Type: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Parameters: []parameters.Parameter{
parameters.NewStringParameter("country", "some description"),
},
AuthRequired: []string{},
},
},
Toolsets: server.ToolsetConfigs{
"example_toolset": tools.ToolsetConfig{
Name: "example_toolset",
ToolNames: []string{"example_tool"},
},
},
Prompts: server.PromptConfigs{
"code_review": custom.Config{
Name: "code_review",
Description: "ask llm to analyze code quality",
Arguments: prompts.Arguments{
{Parameter: parameters.NewStringParameter("code", "the code to review")},
},
Messages: []prompts.Message{
{Role: "user", Content: "please review the following code for quality: {{.code}}"},
},
},
},
},
},
{
description: "only prompts",
description: "with prompts example",
in: `
prompts:
my-prompt:
@@ -1108,7 +689,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Type: cloudsqlpgsrc.SourceType,
Kind: cloudsqlpgsrc.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -1121,19 +702,19 @@ func TestParseToolFileWithAuth(t *testing.T) {
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "my-client-id",
},
"other-google-service": google.Config{
Name: "other-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "other-client-id",
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Type: "postgres-sql",
Kind: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
@@ -1208,7 +789,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Type: cloudsqlpgsrc.SourceType,
Kind: cloudsqlpgsrc.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -1218,22 +799,22 @@ func TestParseToolFileWithAuth(t *testing.T) {
Password: "my_pass",
},
},
AuthServices: server.AuthServiceConfigs{
AuthSources: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "my-client-id",
},
"other-google-service": google.Config{
Name: "other-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "other-client-id",
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Type: "postgres-sql",
Kind: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
@@ -1310,7 +891,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Type: cloudsqlpgsrc.SourceType,
Kind: cloudsqlpgsrc.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -1323,19 +904,19 @@ func TestParseToolFileWithAuth(t *testing.T) {
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "my-client-id",
},
"other-google-service": google.Config{
Name: "other-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "other-client-id",
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Type: "postgres-sql",
Kind: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
@@ -1481,7 +1062,7 @@ func TestEnvVarReplacement(t *testing.T) {
Sources: server.SourceConfigs{
"my-http-instance": httpsrc.Config{
Name: "my-http-instance",
Type: httpsrc.SourceType,
Kind: httpsrc.SourceKind,
BaseURL: "http://test_server/",
Timeout: "10s",
DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"},
@@ -1491,19 +1072,19 @@ func TestEnvVarReplacement(t *testing.T) {
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "ACTUAL_CLIENT_ID",
},
"other-google-service": google.Config{
Name: "other-google-service",
Type: google.AuthServiceType,
Kind: google.AuthServiceKind,
ClientID: "ACTUAL_CLIENT_ID_2",
},
},
Tools: server.ToolConfigs{
"example_tool": http.Config{
Name: "example_tool",
Type: "http",
Kind: "http",
Source: "my-instance",
Method: "GET",
Path: "search?name=alice&pet=cat",
@@ -1912,7 +1493,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_postgres_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_postgres_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup"},
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance"},
},
},
},
@@ -1922,7 +1503,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_mysql_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_mysql_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
},
},
},
@@ -1932,7 +1513,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_mssql_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_mssql_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
},
},
},

View File

View File

@@ -509,7 +509,7 @@
},
"outputs": [],
"source": [
"! pip install toolbox-core --quiet\n",
"! pip install toolbox-adk --quiet\n",
"! pip install google-adk --quiet"
]
},
@@ -525,14 +525,18 @@
"from google.adk.runners import Runner\n",
"from google.adk.sessions import InMemorySessionService\n",
"from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService\n",
"from google.adk.tools.toolbox_toolset import ToolboxToolset\n",
"from google.genai import types\n",
"from toolbox_core import ToolboxSyncClient\n",
"\n",
"import os\n",
"# TODO(developer): replace this with your Google API key\n",
"os.environ['GOOGLE_API_KEY'] = \"<GOOGLE_API_KEY>\"\n",
"\n",
"toolbox_client = ToolboxSyncClient(\"http://127.0.0.1:5000\")\n",
"# Configure toolset\n",
"toolset = ToolboxToolset(\n",
" server_url=\"http://127.0.0.1:5000\",\n",
" toolset_name=\"my-toolset\"\n",
")\n",
"\n",
"prompt = \"\"\"\n",
" You're a helpful hotel assistant. You handle hotel searching, booking and\n",
@@ -549,7 +553,7 @@
" name='hotel_agent',\n",
" description='A helpful AI assistant.',\n",
" instruction=prompt,\n",
" tools=toolbox_client.load_toolset(\"my-toolset\"),\n",
" tools=[toolset],\n",
")\n",
"\n",
"session_service = InMemorySessionService()\n",

View File

@@ -52,7 +52,7 @@ runtime](https://research.google.com/colaboratory/local-runtimes.html).
{{< tabpane persist=header >}}
{{< tab header="ADK" lang="bash" >}}
pip install toolbox-core
pip install toolbox-adk
{{< /tab >}}
{{< tab header="Langchain" lang="bash" >}}

View File

@@ -1,15 +1,17 @@
from google.adk import Agent
from google.adk.apps import App
from toolbox_core import ToolboxSyncClient
from google.adk.tools.toolbox_toolset import ToolboxToolset
# TODO(developer): update the TOOLBOX_URL to your toolbox endpoint
client = ToolboxSyncClient("http://127.0.0.1:5000")
toolset = ToolboxToolset(
server_url="http://127.0.0.1:5000",
)
root_agent = Agent(
name='root_agent',
model='gemini-2.5-flash',
instruction="You are a helpful AI assistant designed to provide accurate and useful information.",
tools=client.load_toolset(),
tools=[toolset],
)
app = App(root_agent=root_agent, name="my_agent")

View File

@@ -1,3 +1,3 @@
google-adk==1.21.0
toolbox-core==0.5.4
toolbox-adk>=0.1.0
pytest==9.0.2

View File

@@ -48,7 +48,6 @@ instance, database and users:
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* `roles/cloudsql.admin`: Provides full control over all resources.
* All `editor` and `viewer` tools
* `create_instance`
@@ -300,7 +299,6 @@ instances and interacting with your database:
* **create_user**: Creates a new user in a Cloud SQL instance.
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for SQL Server instance.
* **create_backup**: Creates a backup on a Cloud SQL instance.
{{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -48,7 +48,6 @@ database and users:
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* `roles/cloudsql.admin`: Provides full control over all resources.
* All `editor` and `viewer` tools
* `create_instance`
@@ -300,7 +299,6 @@ instances and interacting with your database:
* **create_user**: Creates a new user in a Cloud SQL instance.
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for MySQL instance.
* **create_backup**: Creates a backup on a Cloud SQL instance.
{{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -48,7 +48,6 @@ instance, database and users:
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* `roles/cloudsql.admin`: Provides full control over all resources.
* All `editor` and `viewer` tools
* `create_instance`
@@ -300,7 +299,6 @@ instances and interacting with your database:
* **create_user**: Creates a new user in a Cloud SQL instance.
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for PostgreSQL instance.
* **create_backup**: Creates a backup on a Cloud SQL instance.
{{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -49,7 +49,7 @@ with the necessary configuration for deployment to Vertex AI Agent Engine.
4. Add `toolbox-core` as a dependency to the new project:
```bash
uv add toolbox-core
uv add toolbox-adk
```
## Step 3: Configure Google Cloud Authentication
@@ -95,22 +95,23 @@ authentication token.
```python
from google.adk import Agent
from google.adk.apps import App
from toolbox_core import ToolboxSyncClient, auth_methods
from google.adk.tools.toolbox_toolset import ToolboxToolset
from toolbox_adk import CredentialStrategy
# TODO(developer): Replace with your Toolbox Cloud Run Service URL
TOOLBOX_URL = "https://your-toolbox-service-xyz.a.run.app"
# Initialize the client with the Cloud Run URL and Auth headers
client = ToolboxSyncClient(
TOOLBOX_URL,
client_headers={"Authorization": auth_methods.get_google_id_token(TOOLBOX_URL)}
# Initialize the toolset with Workload Identity (generates ID token for the URL)
toolset = ToolboxToolset(
server_url=TOOLBOX_URL,
credentials=CredentialStrategy.workload_identity(target_audience=TOOLBOX_URL)
)
root_agent = Agent(
name='root_agent',
model='gemini-2.5-flash',
instruction="You are a helpful AI assistant designed to provide accurate and useful information.",
tools=client.load_toolset(),
tools=[toolset],
)
app = App(root_agent=root_agent, name="my_agent")

View File

@@ -187,7 +187,6 @@ See [Usage Examples](../reference/cli.md#examples).
manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
all resources.
* All `editor` and `viewer` tools
@@ -204,7 +203,6 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_user`: Creates a new user in a Cloud SQL instance.
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for MySQL instance.
* `create_backup`: Creates a backup on a Cloud SQL instance.
## Cloud SQL for PostgreSQL
@@ -277,7 +275,6 @@ See [Usage Examples](../reference/cli.md#examples).
manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
all resources.
* All `editor` and `viewer` tools
@@ -293,7 +290,6 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_user`: Creates a new user in a Cloud SQL instance.
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for PostgreSQL instance.
* `create_backup`: Creates a backup on a Cloud SQL instance.
## Cloud SQL for SQL Server
@@ -340,7 +336,6 @@ See [Usage Examples](../reference/cli.md#examples).
manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
all resources.
* All `editor` and `viewer` tools
@@ -356,7 +351,6 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_user`: Creates a new user in a Cloud SQL instance.
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for SQL Server instance.
* `create_backup`: Creates a backup on a Cloud SQL instance.
## Dataplex

View File

@@ -1,45 +0,0 @@
---
title: cloud-sql-create-backup
type: docs
weight: 10
description: "Creates a backup on a Cloud SQL instance."
---
The `cloud-sql-create-backup` tool creates an on-demand backup on a Cloud SQL instance using the Cloud SQL Admin API.
{{< notice info dd>}}
This tool uses a `source` of kind `cloud-sql-admin`.
{{< /notice >}}
## Examples
Basic backup creation (current state)
```yaml
tools:
backup-creation-basic:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
description: "Creates a backup on the given Cloud SQL instance."
```
## Reference
### Tool Configuration
| **field** | **type** | **required** | **description** |
| -------------- | :------: | :----------: | ------------------------------------------------------------- |
| kind | string | true | Must be "cloud-sql-create-backup". |
| source | string | true | The name of the `cloud-sql-admin` source to use. |
| description | string | false | A description of the tool. |
### Tool Inputs
| **parameter** | **type** | **required** | **description** |
| -------------------------- | :------: | :----------: | ------------------------------------------------------------------------------- |
| project | string | true | The project ID. |
| instance | string | true | The name of the instance to take a backup on. Does not include the project ID. |
| location | string | false | (Optional) Location of the backup run. |
| backup_description | string | false | (Optional) The description of this backup run. |
## See Also
- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api)
- [Toolbox Cloud SQL tools documentation](../cloudsql)
- [Cloud SQL Backup API documentation](https://cloud.google.com/sql/docs/mysql/backup-recovery/backups)

View File

@@ -365,7 +365,7 @@ pip install llama-index-llms-google-genai
{{< /tab >}}
{{< tab header="ADK" lang="bash" >}}
pip install toolbox-core
pip install toolbox-adk
{{< /tab >}}
{{< /tabpane >}}
@@ -607,8 +607,8 @@ from google.adk.agents import Agent
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.tools.toolbox_toolset import ToolboxToolset
from google.genai import types # For constructing message content
from toolbox_core import ToolboxSyncClient
import os
os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = 'True'
@@ -623,48 +623,47 @@ os.environ['GOOGLE_CLOUD_LOCATION'] = 'us-central1'
# --- Load Tools from Toolbox ---
# TODO(developer): Ensure the Toolbox server is running at <http://127.0.0.1:5000>
# TODO(developer): Ensure the Toolbox server is running at http://127.0.0.1:5000
toolset = ToolboxToolset(server_url="http://127.0.0.1:5000")
with ToolboxSyncClient("<http://127.0.0.1:5000>") as toolbox_client:
# TODO(developer): Replace "my-toolset" with the actual ID of your toolset as configured in your MCP Toolbox server.
agent_toolset = toolbox_client.load_toolset("my-toolset")
# --- Define the Agent's Prompt ---
prompt = """
You're a helpful hotel assistant. You handle hotel searching, booking and
cancellations. When the user searches for a hotel, mention it's name, id,
location and price tier. Always mention hotel ids while performing any
searches. This is very important for any operations. For any bookings or
cancellations, please provide the appropriate confirmation. Be sure to
update checkin or checkout dates if mentioned by the user.
Don't ask for confirmations from the user.
"""
# --- Define the Agent's Prompt ---
prompt = """
You're a helpful hotel assistant. You handle hotel searching, booking and
cancellations. When the user searches for a hotel, mention it's name, id,
location and price tier. Always mention hotel ids while performing any
searches. This is very important for any operations. For any bookings or
cancellations, please provide the appropriate confirmation. Be sure to
update checkin or checkout dates if mentioned by the user.
Don't ask for confirmations from the user.
"""
# --- Configure the Agent ---
# --- Configure the Agent ---
root_agent = Agent(
model='gemini-2.0-flash-001',
name='hotel_agent',
description='A helpful AI assistant that can search and book hotels.',
instruction=prompt,
tools=[toolset], # Pass the loaded toolset
)
root_agent = Agent(
model='gemini-2.0-flash-001',
name='hotel_agent',
description='A helpful AI assistant that can search and book hotels.',
instruction=prompt,
tools=agent_toolset, # Pass the loaded toolset
)
# --- Initialize Services for Running the Agent ---
session_service = InMemorySessionService()
artifacts_service = InMemoryArtifactService()
# --- Initialize Services for Running the Agent ---
session_service = InMemorySessionService()
artifacts_service = InMemoryArtifactService()
runner = Runner(
app_name='hotel_agent',
agent=root_agent,
artifact_service=artifacts_service,
session_service=session_service,
)
async def main():
# Create a new session for the interaction.
session = session_service.create_session(
session = await session_service.create_session(
state={}, app_name='hotel_agent', user_id='123'
)
runner = Runner(
app_name='hotel_agent',
agent=root_agent,
artifact_service=artifacts_service,
session_service=session_service,
)
# --- Define Queries and Run the Agent ---
queries = [
"Find hotels in Basel with Basel in it's name.",
@@ -687,6 +686,10 @@ with ToolboxSyncClient("<http://127.0.0.1:5000>") as toolbox_client:
for text in responses:
print(text)
import asyncio
if __name__ == "__main__":
asyncio.run(main())
{{< /tab >}}
{{< /tabpane >}}

View File

@@ -21,13 +21,13 @@ import (
// AuthServiceConfig is the interface for configuring authentication services.
type AuthServiceConfig interface {
AuthServiceConfigType() string
AuthServiceConfigKind() string
Initialize() (AuthService, error)
}
// AuthService is the interface for authentication services.
type AuthService interface {
AuthServiceType() string
AuthServiceKind() string
GetName() string
GetClaimsFromHeader(context.Context, http.Header) (map[string]any, error)
ToConfig() AuthServiceConfig

View File

@@ -23,7 +23,7 @@ import (
"google.golang.org/api/idtoken"
)
const AuthServiceType string = "google"
const AuthServiceKind string = "google"
// validate interface
var _ auth.AuthServiceConfig = Config{}
@@ -31,13 +31,13 @@ var _ auth.AuthServiceConfig = Config{}
// Auth service configuration
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ClientID string `yaml:"clientId" validate:"required"`
}
// Returns the auth service type
func (cfg Config) AuthServiceConfigType() string {
return AuthServiceType
// Returns the auth service kind
func (cfg Config) AuthServiceConfigKind() string {
return AuthServiceKind
}
// Initialize a Google auth service
@@ -55,9 +55,9 @@ type AuthService struct {
Config
}
// Returns the auth service type
func (a AuthService) AuthServiceType() string {
return AuthServiceType
// Returns the auth service kind
func (a AuthService) AuthServiceKind() string {
return AuthServiceKind
}
func (a AuthService) ToConfig() auth.AuthServiceConfig {

View File

@@ -22,12 +22,12 @@ import (
// EmbeddingModelConfig is the interface for configuring embedding models.
type EmbeddingModelConfig interface {
EmbeddingModelConfigType() string
EmbeddingModelConfigKind() string
Initialize(context.Context) (EmbeddingModel, error)
}
type EmbeddingModel interface {
EmbeddingModelType() string
EmbeddingModelKind() string
ToConfig() EmbeddingModelConfig
EmbedParameters(context.Context, []string) ([][]float32, error)
}

View File

@@ -23,22 +23,22 @@ import (
"google.golang.org/genai"
)
const EmbeddingModelType string = "gemini"
const EmbeddingModelKind string = "gemini"
// validate interface
var _ embeddingmodels.EmbeddingModelConfig = Config{}
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Model string `yaml:"model" validate:"required"`
ApiKey string `yaml:"apiKey"`
Dimension int32 `yaml:"dimension"`
}
// Returns the embedding model type
func (cfg Config) EmbeddingModelConfigType() string {
return EmbeddingModelType
// Returns the embedding model kind
func (cfg Config) EmbeddingModelConfigKind() string {
return EmbeddingModelKind
}
// Initialize a Gemini embedding model
@@ -69,9 +69,9 @@ type EmbeddingModel struct {
Config
}
// Returns the embedding model type
func (m EmbeddingModel) EmbeddingModelType() string {
return EmbeddingModelType
// Returns the embedding model kind
func (m EmbeddingModel) EmbeddingModelKind() string {
return EmbeddingModelKind
}
func (m EmbeddingModel) ToConfig() embeddingmodels.EmbeddingModelConfig {

View File

@@ -15,9 +15,9 @@
package gemini_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
@@ -34,15 +34,15 @@ func TestParseFromYamlGemini(t *testing.T) {
{
desc: "basic example",
in: `
kind: embeddingModels
name: my-gemini-model
type: gemini
model: text-embedding-004
embeddingModels:
my-gemini-model:
kind: gemini
model: text-embedding-004
`,
want: map[string]embeddingmodels.EmbeddingModelConfig{
"my-gemini-model": gemini.Config{
Name: "my-gemini-model",
Type: gemini.EmbeddingModelType,
Kind: gemini.EmbeddingModelKind,
Model: "text-embedding-004",
},
},
@@ -50,17 +50,17 @@ func TestParseFromYamlGemini(t *testing.T) {
{
desc: "full example with optional fields",
in: `
kind: embeddingModels
name: complex-gemini
type: gemini
model: text-embedding-004
apiKey: "test-api-key"
dimension: 768
embeddingModels:
complex-gemini:
kind: gemini
model: text-embedding-004
apiKey: "test-api-key"
dimension: 768
`,
want: map[string]embeddingmodels.EmbeddingModelConfig{
"complex-gemini": gemini.Config{
Name: "complex-gemini",
Type: gemini.EmbeddingModelType,
Kind: gemini.EmbeddingModelKind,
Model: "text-embedding-004",
ApiKey: "test-api-key",
Dimension: 768,
@@ -70,13 +70,16 @@ func TestParseFromYamlGemini(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
}{}
// Parse contents
_, _, got, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: %v", cmp.Diff(tc.want, got))
if !cmp.Equal(tc.want, got.Models) {
t.Fatalf("incorrect parse: %v", cmp.Diff(tc.want, got.Models))
}
})
}
@@ -90,29 +93,32 @@ func TestFailParseFromYamlGemini(t *testing.T) {
{
desc: "missing required model field",
in: `
kind: embeddingModels
name: bad-model
type: gemini
embeddingModels:
bad-model:
kind: gemini
`,
// Removed the specific model name from the prefix to match your output
err: "error unmarshaling embeddingModels: unable to parse as \"bad-model\": Key: 'Config.Model' Error:Field validation for 'Model' failed on the 'required' tag",
err: "unable to parse as \"gemini\": Key: 'Config.Model' Error:Field validation for 'Model' failed on the 'required' tag",
},
{
desc: "unknown field",
in: `
kind: embeddingModels
name: bad-field
type: gemini
model: text-embedding-004
invalid_param: true
embeddingModels:
bad-field:
kind: gemini
model: text-embedding-004
invalid_param: true
`,
// Updated to match the specific line-starting format of your error output
err: "error unmarshaling embeddingModels: unable to parse as \"bad-field\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | model: text-embedding-004\n 3 | name: bad-field\n 4 | type: gemini",
err: "unable to parse as \"gemini\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | kind: gemini\n 3 | model: text-embedding-004",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
}{}
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -43,9 +43,6 @@ tools:
clone_instance:
kind: cloud-sql-clone-instance
source: cloud-sql-admin-source
create_backup:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
toolsets:
cloud_sql_mssql_admin_tools:
@@ -57,4 +54,3 @@ toolsets:
- create_user
- wait_for_operation
- clone_instance
- create_backup

View File

@@ -43,9 +43,6 @@ tools:
clone_instance:
kind: cloud-sql-clone-instance
source: cloud-sql-admin-source
create_backup:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
toolsets:
cloud_sql_mysql_admin_tools:
@@ -57,4 +54,3 @@ toolsets:
- create_user
- wait_for_operation
- clone_instance
- create_backup

View File

@@ -46,9 +46,6 @@ tools:
postgres_upgrade_precheck:
kind: postgres-upgrade-precheck
source: cloud-sql-admin-source
create_backup:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
toolsets:
cloud_sql_postgres_admin_tools:
@@ -61,4 +58,3 @@ toolsets:
- wait_for_operation
- postgres_upgrade_precheck
- clone_instance
- create_backup

View File

@@ -25,12 +25,12 @@ import (
type Message = prompts.Message
const resourceType = "custom"
const kind = "custom"
// init registers this prompt type with the prompt framework.
// init registers this prompt kind with the prompt framework.
func init() {
if !prompts.Register(resourceType, newConfig) {
panic(fmt.Sprintf("prompt type %q already registered", resourceType))
if !prompts.Register(kind, newConfig) {
panic(fmt.Sprintf("prompt kind %q already registered", kind))
}
}
@@ -56,8 +56,8 @@ type Config struct {
var _ prompts.PromptConfig = Config{}
var _ prompts.Prompt = Prompt{}
func (c Config) PromptConfigType() string {
return resourceType
func (c Config) PromptConfigKind() string {
return kind
}
func (c Config) Initialize() (prompts.Prompt, error) {

View File

@@ -42,7 +42,7 @@ func TestConfig(t *testing.T) {
Arguments: testArgs,
}
// initialize and check type
// initialize and check kind
p, err := cfg.Initialize()
if err != nil {
t.Fatalf("Initialize() failed: %v", err)
@@ -50,8 +50,8 @@ func TestConfig(t *testing.T) {
if p == nil {
t.Fatal("Initialize() returned a nil prompt")
}
if cfg.PromptConfigType() != "custom" {
t.Errorf("PromptConfigType() = %q, want %q", cfg.PromptConfigType(), "custom")
if cfg.PromptConfigKind() != "custom" {
t.Errorf("PromptConfigKind() = %q, want %q", cfg.PromptConfigKind(), "custom")
}
t.Run("Manifest", func(t *testing.T) {

View File

@@ -30,40 +30,40 @@ var promptRegistry = make(map[string]PromptConfigFactory)
// Register allows individual prompt packages to register their configuration
// factory function. This is typically called from an init() function in the
// prompt's package. It associates a 'type' string with a function that can
// prompt's package. It associates a 'kind' string with a function that can
// produce the specific PromptConfig type. It returns true if the registration was
// successful, and false if a prompt with the same type was already registered.
func Register(resourceType string, factory PromptConfigFactory) bool {
if _, exists := promptRegistry[resourceType]; exists {
// Prompt with this type already exists, do not overwrite.
// successful, and false if a prompt with the same kind was already registered.
func Register(kind string, factory PromptConfigFactory) bool {
if _, exists := promptRegistry[kind]; exists {
// Prompt with this kind already exists, do not overwrite.
return false
}
promptRegistry[resourceType] = factory
promptRegistry[kind] = factory
return true
}
// DecodeConfig looks up the registered factory for the given type and uses it
// DecodeConfig looks up the registered factory for the given kind and uses it
// to decode the prompt configuration.
func DecodeConfig(ctx context.Context, resourceType, name string, decoder *yaml.Decoder) (PromptConfig, error) {
factory, found := promptRegistry[resourceType]
if !found && resourceType == "" {
resourceType = "custom"
factory, found = promptRegistry[resourceType]
func DecodeConfig(ctx context.Context, kind, name string, decoder *yaml.Decoder) (PromptConfig, error) {
factory, found := promptRegistry[kind]
if !found && kind == "" {
kind = "custom"
factory, found = promptRegistry[kind]
}
if !found {
return nil, fmt.Errorf("unknown prompt type: %q", resourceType)
return nil, fmt.Errorf("unknown prompt kind: %q", kind)
}
promptConfig, err := factory(ctx, name, decoder)
if err != nil {
return nil, fmt.Errorf("unable to parse prompt %q as resourceType %q: %w", name, resourceType, err)
return nil, fmt.Errorf("unable to parse prompt %q as kind %q: %w", name, kind, err)
}
return promptConfig, nil
}
type PromptConfig interface {
PromptConfigType() string
PromptConfigKind() string
Initialize() (Prompt, error)
}

View File

@@ -29,16 +29,16 @@ import (
type mockPromptConfig struct {
name string
Type string
kind string
}
func (m *mockPromptConfig) PromptConfigType() string { return m.Type }
func (m *mockPromptConfig) PromptConfigKind() string { return m.kind }
func (m *mockPromptConfig) Initialize() (prompts.Prompt, error) { return nil, nil }
var errMockFactory = errors.New("mock factory error")
func mockFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
return &mockPromptConfig{name: name, Type: "mockType"}, nil
return &mockPromptConfig{name: name, kind: "mockKind"}, nil
}
func mockErrorFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
@@ -50,17 +50,17 @@ func TestRegistry(t *testing.T) {
ctx := context.Background()
t.Run("RegisterAndDecodeSuccess", func(t *testing.T) {
resourceType := "testTypeSuccess"
if !prompts.Register(resourceType, mockFactory) {
kind := "testKindSuccess"
if !prompts.Register(kind, mockFactory) {
t.Fatal("expected registration to succeed")
}
// This should fail because we are registering a duplicate
if prompts.Register(resourceType, mockFactory) {
if prompts.Register(kind, mockFactory) {
t.Fatal("expected duplicate registration to fail")
}
decoder := yaml.NewDecoder(strings.NewReader(""))
config, err := prompts.DecodeConfig(ctx, resourceType, "testPrompt", decoder)
config, err := prompts.DecodeConfig(ctx, kind, "testPrompt", decoder)
if err != nil {
t.Fatalf("expected DecodeConfig to succeed, but got error: %v", err)
}
@@ -69,25 +69,25 @@ func TestRegistry(t *testing.T) {
}
})
t.Run("DecodeUnknownType", func(t *testing.T) {
t.Run("DecodeUnknownKind", func(t *testing.T) {
decoder := yaml.NewDecoder(strings.NewReader(""))
_, err := prompts.DecodeConfig(ctx, "unregisteredType", "testPrompt", decoder)
_, err := prompts.DecodeConfig(ctx, "unregisteredKind", "testPrompt", decoder)
if err == nil {
t.Fatal("expected an error for unknown type, but got nil")
t.Fatal("expected an error for unknown kind, but got nil")
}
if !strings.Contains(err.Error(), "unknown prompt type") {
t.Errorf("expected error to contain 'unknown prompt type', but got: %v", err)
if !strings.Contains(err.Error(), "unknown prompt kind") {
t.Errorf("expected error to contain 'unknown prompt kind', but got: %v", err)
}
})
t.Run("FactoryReturnsError", func(t *testing.T) {
resourceType := "testTypeError"
if !prompts.Register(resourceType, mockErrorFactory) {
kind := "testKindError"
if !prompts.Register(kind, mockErrorFactory) {
t.Fatal("expected registration to succeed")
}
decoder := yaml.NewDecoder(strings.NewReader(""))
_, err := prompts.DecodeConfig(ctx, resourceType, "testPrompt", decoder)
_, err := prompts.DecodeConfig(ctx, kind, "testPrompt", decoder)
if err == nil {
t.Fatal("expected an error from the factory, but got nil")
}
@@ -100,13 +100,13 @@ func TestRegistry(t *testing.T) {
decoder := yaml.NewDecoder(strings.NewReader("description: A test prompt"))
config, err := prompts.DecodeConfig(ctx, "", "testDefaultPrompt", decoder)
if err != nil {
t.Fatalf("expected DecodeConfig with empty type to succeed, but got error: %v", err)
t.Fatalf("expected DecodeConfig with empty kind to succeed, but got error: %v", err)
}
if config == nil {
t.Fatal("expected a non-nil config for default type")
t.Fatal("expected a non-nil config for default kind")
}
if config.PromptConfigType() != "custom" {
t.Errorf("expected default type to be 'custom', but got %q", config.PromptConfigType())
if config.PromptConfigKind() != "custom" {
t.Errorf("expected default kind to be 'custom', but got %q", config.PromptConfigKind())
}
})
}

View File

@@ -14,10 +14,8 @@
package server
import (
"bytes"
"context"
"fmt"
"io"
"strings"
yaml "github.com/goccy/go-yaml"
@@ -126,201 +124,272 @@ func (s *StringLevel) Type() string {
return "stringLevel"
}
// SourceConfigs is a type used to allow unmarshal of the data source config map
type SourceConfigs map[string]sources.SourceConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &SourceConfigs{}
func (c *SourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(SourceConfigs)
// Parse the 'kind' fields for each source
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
// Unmarshal to a general type that ensure it capture all fields
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
}
kind, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for source %q", name)
}
kindStr, ok := kind.(string)
if !ok {
return fmt.Errorf("invalid 'kind' field for source %q (must be a string)", name)
}
yamlDecoder, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating YAML decoder for source %q: %w", name, err)
}
sourceConfig, err := sources.DecodeConfig(ctx, kindStr, name, yamlDecoder)
if err != nil {
return err
}
(*c)[name] = sourceConfig
}
return nil
}
// AuthServiceConfigs is a type used to allow unmarshal of the data authService config map
type AuthServiceConfigs map[string]auth.AuthServiceConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &AuthServiceConfigs{}
func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(AuthServiceConfigs)
// Parse the 'kind' fields for each authService
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
}
kind, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for %q", name)
}
dec, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating decoder: %w", err)
}
switch kind {
case google.AuthServiceKind:
actual := google.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of auth source", kind)
}
}
return nil
}
// EmbeddingModelConfigs is a type used to allow unmarshal of the embedding model config map
type EmbeddingModelConfigs map[string]embeddingmodels.EmbeddingModelConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &EmbeddingModelConfigs{}
func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(EmbeddingModelConfigs)
// Parse the 'kind' fields for each embedding model
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
// Unmarshal to a general type that ensure it capture all fields
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal embedding model %q: %w", name, err)
}
kind, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for embedding model %q", name)
}
dec, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating decoder: %w", err)
}
switch kind {
case gemini.EmbeddingModelKind:
actual := gemini.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of auth source", kind)
}
}
return nil
}
// ToolConfigs is a type used to allow unmarshal of the tool configs
type ToolConfigs map[string]tools.ToolConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &ToolConfigs{}
func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(ToolConfigs)
// Parse the 'kind' fields for each source
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
}
// `authRequired` and `useClientOAuth` cannot be specified together
if v["authRequired"] != nil && v["useClientOAuth"] == true {
return fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method")
}
// Make `authRequired` an empty list instead of nil for Tool manifest
if v["authRequired"] == nil {
v["authRequired"] = []string{}
}
kindVal, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for tool %q", name)
}
kindStr, ok := kindVal.(string)
if !ok {
return fmt.Errorf("invalid 'kind' field for tool %q (must be a string)", name)
}
yamlDecoder, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating YAML decoder for tool %q: %w", name, err)
}
toolCfg, err := tools.DecodeConfig(ctx, kindStr, name, yamlDecoder)
if err != nil {
return err
}
(*c)[name] = toolCfg
}
return nil
}
// ToolsetConfigs is a type used to allow unmarshal of the toolset configs
type ToolsetConfigs map[string]tools.ToolsetConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &ToolsetConfigs{}
func (c *ToolsetConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(ToolsetConfigs)
var raw map[string][]string
if err := unmarshal(&raw); err != nil {
return err
}
for name, toolList := range raw {
(*c)[name] = tools.ToolsetConfig{Name: name, ToolNames: toolList}
}
return nil
}
// PromptConfigs is a type used to allow unmarshal of the prompt configs
type PromptConfigs map[string]prompts.PromptConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &PromptConfigs{}
func (c *PromptConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(PromptConfigs)
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal prompt %q: %w", name, err)
}
// Look for the 'kind' field. If it's not present, kindStr will be an
// empty string, which prompts.DecodeConfig will correctly default to "custom".
var kindStr string
if kindVal, ok := v["kind"]; ok {
var isString bool
kindStr, isString = kindVal.(string)
if !isString {
return fmt.Errorf("invalid 'kind' field for prompt %q (must be a string)", name)
}
}
// Create a new, strict decoder for this specific prompt's data.
yamlDecoder, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating YAML decoder for prompt %q: %w", name, err)
}
// Use the central registry to decode the prompt based on its kind.
promptCfg, err := prompts.DecodeConfig(ctx, kindStr, name, yamlDecoder)
if err != nil {
return err
}
(*c)[name] = promptCfg
}
return nil
}
// PromptsetConfigs is a type used to allow unmarshal of the PromptsetConfigs configs
type PromptsetConfigs map[string]prompts.PromptsetConfig
func UnmarshalResourceConfig(ctx context.Context, raw []byte) (SourceConfigs, AuthServiceConfigs, EmbeddingModelConfigs, ToolConfigs, ToolsetConfigs, PromptConfigs, error) {
// prepare configs map
sourceConfigs := make(map[string]sources.SourceConfig)
authServiceConfigs := make(AuthServiceConfigs)
embeddingModelConfigs := make(EmbeddingModelConfigs)
toolConfigs := make(ToolConfigs)
toolsetConfigs := make(ToolsetConfigs)
promptConfigs := make(PromptConfigs)
// promptset configs is not yet supported
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &PromptsetConfigs{}
decoder := yaml.NewDecoder(bytes.NewReader(raw))
// for loop to unmarshal documents with the `---` separator
for {
var resource map[string]any
if err := decoder.DecodeContext(ctx, &resource); err != nil {
if err == io.EOF {
break
}
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unable to decode YAML document: %w", err)
}
var kind, name string
var ok bool
if kind, ok = resource["kind"].(string); !ok {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("missing 'kind' field or it is not a string")
}
if name, ok = resource["name"].(string); !ok {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("missing 'name' field or it is not a string")
}
// remove 'kind' from map for strict unmarshaling
delete(resource, "kind")
func (c *PromptsetConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(PromptsetConfigs)
switch kind {
case "sources":
c, err := UnmarshalYAMLSourceConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
sourceConfigs[name] = c
case "authServices":
c, err := UnmarshalYAMLAuthServiceConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
authServiceConfigs[name] = c
case "tools":
c, err := UnmarshalYAMLToolConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
toolConfigs[name] = c
case "toolsets":
c, err := UnmarshalYAMLToolsetConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
toolsetConfigs[name] = c
case "embeddingModels":
c, err := UnmarshalYAMLEmbeddingModelConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
embeddingModelConfigs[name] = c
case "prompts":
c, err := UnmarshalYAMLPromptConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
promptConfigs[name] = c
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("invalid kind %s", kind)
}
}
return sourceConfigs, authServiceConfigs, embeddingModelConfigs, toolConfigs, toolsetConfigs, promptConfigs, nil
}
func UnmarshalYAMLSourceConfig(ctx context.Context, name string, r map[string]any) (sources.SourceConfig, error) {
resourceType, ok := r["type"].(string)
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %w", err)
}
sourceConfig, err := sources.DecodeConfig(ctx, resourceType, name, dec)
if err != nil {
return nil, err
}
return sourceConfig, nil
}
func UnmarshalYAMLAuthServiceConfig(ctx context.Context, name string, r map[string]any) (auth.AuthServiceConfig, error) {
resourceType, ok := r["type"].(string)
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}
if resourceType != google.AuthServiceType {
return nil, fmt.Errorf("%s is not a valid type of auth service", resourceType)
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %s", err)
}
actual := google.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return nil, fmt.Errorf("unable to parse as %s: %w", name, err)
}
return actual, nil
}
func UnmarshalYAMLEmbeddingModelConfig(ctx context.Context, name string, r map[string]any) (embeddingmodels.EmbeddingModelConfig, error) {
resourceType, ok := r["type"].(string)
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}
if resourceType != gemini.EmbeddingModelType {
return nil, fmt.Errorf("%s is not a valid type of embedding model", resourceType)
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %s", err)
}
actual := gemini.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", name, err)
}
return actual, nil
}
func UnmarshalYAMLToolConfig(ctx context.Context, name string, r map[string]any) (tools.ToolConfig, error) {
resourceType, ok := r["type"].(string)
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}
// `authRequired` and `useClientOAuth` cannot be specified together
if r["authRequired"] != nil && r["useClientOAuth"] == true {
return nil, fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method")
}
// Make `authRequired` an empty list instead of nil for Tool manifest
if r["authRequired"] == nil {
r["authRequired"] = []string{}
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %s", err)
}
toolCfg, err := tools.DecodeConfig(ctx, resourceType, name, dec)
if err != nil {
return nil, err
}
return toolCfg, nil
}
func UnmarshalYAMLToolsetConfig(ctx context.Context, name string, r map[string]any) (tools.ToolsetConfig, error) {
var toolsetConfig tools.ToolsetConfig
justTools := map[string]any{"tools": r["tools"]}
dec, err := util.NewStrictDecoder(justTools)
if err != nil {
return toolsetConfig, fmt.Errorf("error creating decoder: %s", err)
}
var raw map[string][]string
if err := dec.DecodeContext(ctx, &raw); err != nil {
return toolsetConfig, fmt.Errorf("unable to unmarshal tools: %s", err)
}
return tools.ToolsetConfig{Name: name, ToolNames: raw["tools"]}, nil
}
func UnmarshalYAMLPromptConfig(ctx context.Context, name string, r map[string]any) (prompts.PromptConfig, error) {
// Look for the 'type' field. If it's not present, typeStr will be an
// empty string, which prompts.DecodeConfig will correctly default to "custom".
var resourceType string
if typeVal, ok := r["type"]; ok {
var isString bool
resourceType, isString = typeVal.(string)
if !isString {
return nil, fmt.Errorf("invalid 'type' field for prompt %q (must be a string)", name)
}
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %s", err)
if err := unmarshal(&raw); err != nil {
return err
}
// Use the central registry to decode the prompt based on its type.
promptCfg, err := prompts.DecodeConfig(ctx, resourceType, name, dec)
if err != nil {
return nil, err
for name, promptList := range raw {
(*c)[name] = prompts.PromptsetConfig{Name: name, PromptNames: promptList}
}
return promptCfg, nil
return nil
}

View File

@@ -32,7 +32,7 @@ func TestUpdateServer(t *testing.T) {
"example-source": &alloydbpg.Source{
Config: alloydbpg.Config{
Name: "example-alloydb-source",
Type: "alloydb-postgres",
Kind: "alloydb-postgres",
},
},
}
@@ -92,7 +92,7 @@ func TestUpdateServer(t *testing.T) {
"example-source2": &alloydbpg.Source{
Config: alloydbpg.Config{
Name: "example-alloydb-source2",
Type: "alloydb-postgres",
Kind: "alloydb-postgres",
},
},
}

View File

@@ -82,7 +82,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
childCtx, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/source/init",
trace.WithAttributes(attribute.String("source_type", sc.SourceConfigType())),
trace.WithAttributes(attribute.String("source_kind", sc.SourceConfigKind())),
trace.WithAttributes(attribute.String("source_name", name)),
)
defer span.End()
@@ -110,7 +110,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/auth/init",
trace.WithAttributes(attribute.String("auth_type", sc.AuthServiceConfigType())),
trace.WithAttributes(attribute.String("auth_kind", sc.AuthServiceConfigKind())),
trace.WithAttributes(attribute.String("auth_name", name)),
)
defer span.End()
@@ -138,7 +138,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/embeddingmodel/init",
trace.WithAttributes(attribute.String("model_type", ec.EmbeddingModelConfigType())),
trace.WithAttributes(attribute.String("model_kind", ec.EmbeddingModelConfigKind())),
trace.WithAttributes(attribute.String("model_name", name)),
)
defer span.End()
@@ -166,7 +166,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/tool/init",
trace.WithAttributes(attribute.String("tool_type", tc.ToolConfigType())),
trace.WithAttributes(attribute.String("tool_kind", tc.ToolConfigKind())),
trace.WithAttributes(attribute.String("tool_name", name)),
)
defer span.End()
@@ -235,7 +235,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/prompt/init",
trace.WithAttributes(attribute.String("prompt_type", pc.PromptConfigType())),
trace.WithAttributes(attribute.String("prompt_kind", pc.PromptConfigKind())),
trace.WithAttributes(attribute.String("prompt_name", name)),
)
defer span.End()

View File

@@ -141,7 +141,7 @@ func TestUpdateServer(t *testing.T) {
"example-source": &alloydbpg.Source{
Config: alloydbpg.Config{
Name: "example-alloydb-source",
Type: "alloydb-postgres",
Kind: "alloydb-postgres",
},
},
}

View File

@@ -32,14 +32,14 @@ import (
"google.golang.org/api/option"
)
const SourceType string = "alloydb-admin"
const SourceKind string = "alloydb-admin"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -53,13 +53,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
DefaultProject string `yaml:"defaultProject"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -106,8 +106,8 @@ type Source struct {
Service *alloydbrestapi.Service
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,9 +15,9 @@
package alloydbadmin_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -34,14 +34,14 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-alloydb-admin-instance
type: alloydb-admin
sources:
my-alloydb-admin-instance:
kind: alloydb-admin
`,
want: map[string]sources.SourceConfig{
"my-alloydb-admin-instance": alloydbadmin.Config{
Name: "my-alloydb-admin-instance",
Type: alloydbadmin.SourceType,
Kind: alloydbadmin.SourceKind,
UseClientOAuth: false,
},
},
@@ -49,15 +49,15 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
{
desc: "use client auth example",
in: `
kind: sources
name: my-alloydb-admin-instance
type: alloydb-admin
useClientOAuth: true
sources:
my-alloydb-admin-instance:
kind: alloydb-admin
useClientOAuth: true
`,
want: map[string]sources.SourceConfig{
"my-alloydb-admin-instance": alloydbadmin.Config{
Name: "my-alloydb-admin-instance",
Type: alloydbadmin.SourceType,
Kind: alloydbadmin.SourceKind,
UseClientOAuth: true,
},
},
@@ -65,13 +65,16 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -86,27 +89,30 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-alloydb-admin-instance
type: alloydb-admin
project: test-project
sources:
my-alloydb-admin-instance:
kind: alloydb-admin
project: test-project
`,
err: "error unmarshaling sources: unable to parse source \"my-alloydb-admin-instance\" as \"alloydb-admin\": [2:1] unknown field \"project\"\n 1 | name: my-alloydb-admin-instance\n> 2 | project: test-project\n ^\n 3 | type: alloydb-admin",
err: "unable to parse source \"my-alloydb-admin-instance\" as \"alloydb-admin\": [2:1] unknown field \"project\"\n 1 | kind: alloydb-admin\n> 2 | project: test-project\n ^\n",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-alloydb-admin-instance
useClientOAuth: true
sources:
my-alloydb-admin-instance:
useClientOAuth: true
`,
err: "error unmarshaling sources: missing 'type' field or it is not a string",
err: "missing 'kind' field for source \"my-alloydb-admin-instance\"",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "alloydb-postgres"
const SourceKind string = "alloydb-postgres"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Cluster string `yaml:"cluster" validate:"required"`
@@ -61,8 +61,8 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -90,8 +90,8 @@ type Source struct {
Pool *pgxpool.Pool
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -183,7 +183,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string
func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, cluster, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
dsn, useIAM, err := getConnectionConfig(ctx, user, pass, dbname)

View File

@@ -15,9 +15,9 @@
package alloydbpg_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -34,21 +34,21 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-pg-instance
type: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
database: my_db
user: my_user
password: my_pass
sources:
my-pg-instance:
kind: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
"my-pg-instance": alloydbpg.Config{
Name: "my-pg-instance",
Type: alloydbpg.SourceType,
Kind: alloydbpg.SourceKind,
Project: "my-project",
Region: "my-region",
Cluster: "my-cluster",
@@ -63,22 +63,22 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
{
desc: "public ipType",
in: `
kind: sources
name: my-pg-instance
type: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
ipType: Public
database: my_db
user: my_user
password: my_pass
sources:
my-pg-instance:
kind: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
ipType: Public
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
"my-pg-instance": alloydbpg.Config{
Name: "my-pg-instance",
Type: alloydbpg.SourceType,
Kind: alloydbpg.SourceKind,
Project: "my-project",
Region: "my-region",
Cluster: "my-cluster",
@@ -93,22 +93,22 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
{
desc: "private ipType",
in: `
kind: sources
name: my-pg-instance
type: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
ipType: private
database: my_db
user: my_user
password: my_pass
sources:
my-pg-instance:
kind: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
ipType: private
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
"my-pg-instance": alloydbpg.Config{
Name: "my-pg-instance",
Type: alloydbpg.SourceType,
Kind: alloydbpg.SourceKind,
Project: "my-project",
Region: "my-region",
Cluster: "my-cluster",
@@ -123,13 +123,16 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -144,56 +147,60 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "invalid ipType",
in: `
kind: sources
name: my-pg-instance
type: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
sources:
my-pg-instance:
kind: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
err: "unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
},
{
desc: "extra field",
in: `
kind: sources
name: my-pg-instance
type: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
sources:
my-pg-instance:
kind: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": [3:1] unknown field \"foo\"\n 1 | cluster: my-cluster\n 2 | database: my_db\n> 3 | foo: bar\n ^\n 4 | instance: my-instance\n 5 | name: my-pg-instance\n 6 | password: my_pass\n 7 | ",
err: "unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": [3:1] unknown field \"foo\"\n 1 | cluster: my-cluster\n 2 | database: my_db\n> 3 | foo: bar\n ^\n 4 | instance: my-instance\n 5 | kind: alloydb-postgres\n 6 | password: my_pass\n 7 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-pg-instance
type: alloydb-postgres
region: my-region
cluster: my-cluster
instance: my-instance
database: my_db
user: my_user
password: my_pass
sources:
my-pg-instance:
kind: alloydb-postgres
region: my-region
cluster: my-cluster
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -41,7 +41,7 @@ import (
"google.golang.org/api/option"
)
const SourceType string = "bigquery"
const SourceKind string = "bigquery"
// CloudPlatformScope is a broad scope for Google Cloud Platform services.
const CloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
@@ -65,8 +65,8 @@ type BigQuerySessionProvider func(ctx context.Context) (*Session, error)
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -81,7 +81,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// BigQuery configs
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Location string `yaml:"location"`
WriteMode string `yaml:"writeMode"`
@@ -119,9 +119,9 @@ func (s *StringOrStringSlice) UnmarshalYAML(unmarshal func(any) error) error {
return fmt.Errorf("cannot unmarshal %T into StringOrStringSlice", v)
}
func (r Config) SourceConfigType() string {
// Returns BigQuery source type
return SourceType
func (r Config) SourceConfigKind() string {
// Returns BigQuery source kind
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
if r.WriteMode == "" {
@@ -302,9 +302,9 @@ type Session struct {
LastUsed time.Time
}
func (s *Source) SourceType() string {
// Returns BigQuery Google SQL source type
return SourceType
func (s *Source) SourceKind() string {
// Returns BigQuery Google SQL source kind
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -665,7 +665,7 @@ func initBigQueryConnection(
impersonateServiceAccount string,
scopes []string,
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)
@@ -741,7 +741,7 @@ func initBigQueryConnectionWithOAuthToken(
tokenString string,
wantRestService bool,
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Construct token source
token := &oauth2.Token{
@@ -801,7 +801,7 @@ func initDataplexConnection(
var clientCreator DataplexClientCreator
var err error
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -15,18 +15,18 @@
package bigquery_test
import (
"context"
"math/big"
"reflect"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"go.opentelemetry.io/otel/trace/noop"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace/noop"
)
func TestParseFromYamlBigQuery(t *testing.T) {
@@ -38,15 +38,15 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-instance
type: bigquery
project: my-project
sources:
my-instance:
kind: bigquery
project: my-project
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "",
WriteMode: "",
@@ -56,17 +56,17 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "all fields specified",
in: `
kind: sources
name: my-instance
type: bigquery
project: my-project
location: asia
writeMode: blocked
sources:
my-instance:
kind: bigquery
project: my-project
location: asia
writeMode: blocked
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "asia",
WriteMode: "blocked",
@@ -77,17 +77,17 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "use client auth example",
in: `
kind: sources
name: my-instance
type: bigquery
project: my-project
location: us
useClientOAuth: true
sources:
my-instance:
kind: bigquery
project: my-project
location: us
useClientOAuth: true
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
UseClientOAuth: true,
@@ -97,18 +97,18 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "with allowed datasets example",
in: `
kind: sources
name: my-instance
type: bigquery
project: my-project
location: us
allowedDatasets:
- my_dataset
sources:
my-instance:
kind: bigquery
project: my-project
location: us
allowedDatasets:
- my_dataset
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
AllowedDatasets: []string{"my_dataset"},
@@ -118,17 +118,17 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "with service account impersonation example",
in: `
kind: sources
name: my-instance
type: bigquery
project: my-project
location: us
impersonateServiceAccount: service-account@my-project.iam.gserviceaccount.com
sources:
my-instance:
kind: bigquery
project: my-project
location: us
impersonateServiceAccount: service-account@my-project.iam.gserviceaccount.com
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
ImpersonateServiceAccount: "service-account@my-project.iam.gserviceaccount.com",
@@ -138,19 +138,19 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "with custom scopes example",
in: `
kind: sources
name: my-instance
type: bigquery
project: my-project
location: us
scopes:
- https://www.googleapis.com/auth/bigquery
- https://www.googleapis.com/auth/cloud-platform
sources:
my-instance:
kind: bigquery
project: my-project
location: us
scopes:
- https://www.googleapis.com/auth/bigquery
- https://www.googleapis.com/auth/cloud-platform
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
Scopes: []string{"https://www.googleapis.com/auth/bigquery", "https://www.googleapis.com/auth/cloud-platform"},
@@ -160,17 +160,17 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "with max query result rows example",
in: `
kind: sources
name: my-instance
type: bigquery
project: my-project
location: us
maxQueryResultRows: 10
sources:
my-instance:
kind: bigquery
project: my-project
location: us
maxQueryResultRows: 10
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
MaxQueryResultRows: 10,
@@ -180,15 +180,20 @@ func TestParseFromYamlBigQuery(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Fatalf("incorrect parse (-want +got):\n%s", diff)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
}
func TestFailParseFromYaml(t *testing.T) {
@@ -200,29 +205,33 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-instance
type: bigquery
project: my-project
location: us
foo: bar
sources:
my-instance:
kind: bigquery
project: my-project
location: us
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"bigquery\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | location: us\n 3 | name: my-instance\n 4 | project: my-project\n 5 | ",
err: "unable to parse source \"my-instance\" as \"bigquery\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: bigquery\n 3 | location: us\n 4 | project: my-project",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-instance
type: bigquery
location: us
sources:
my-instance:
kind: bigquery
location: us
`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"bigquery\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "unable to parse source \"my-instance\" as \"bigquery\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}
@@ -251,7 +260,7 @@ func TestInitialize_MaxQueryResultRows(t *testing.T) {
desc: "default value",
cfg: bigquery.Config{
Name: "test-default",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "test-project",
UseClientOAuth: true,
},
@@ -261,7 +270,7 @@ func TestInitialize_MaxQueryResultRows(t *testing.T) {
desc: "configured value",
cfg: bigquery.Config{
Name: "test-configured",
Type: bigquery.SourceType,
Kind: bigquery.SourceKind,
Project: "test-project",
UseClientOAuth: true,
MaxQueryResultRows: 100,

View File

@@ -27,14 +27,14 @@ import (
"google.golang.org/api/option"
)
const SourceType string = "bigtable"
const SourceKind string = "bigtable"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -48,13 +48,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -77,8 +77,8 @@ type Source struct {
Client *bigtable.Client
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -179,7 +179,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, configParam param
func initBigtableClient(ctx context.Context, tracer trace.Tracer, name, project, instance string) (*bigtable.Client, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Set up Bigtable data operations client.

View File

@@ -15,9 +15,9 @@
package bigtable_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -34,16 +34,16 @@ func TestParseFromYamlBigtableDb(t *testing.T) {
{
desc: "can configure with a bigtable table",
in: `
kind: sources
name: my-bigtable-instance
type: bigtable
project: my-project
instance: my-instance
sources:
my-bigtable-instance:
kind: bigtable
project: my-project
instance: my-instance
`,
want: map[string]sources.SourceConfig{
"my-bigtable-instance": bigtable.Config{
Name: "my-bigtable-instance",
Type: bigtable.SourceType,
Kind: bigtable.SourceKind,
Project: "my-project",
Instance: "my-instance",
},
@@ -52,12 +52,16 @@ func TestParseFromYamlBigtableDb(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -73,29 +77,33 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-bigtable-instance
type: bigtable
project: my-project
instance: my-instance
foo: bar
sources:
my-bigtable-instance:
kind: bigtable
project: my-project
instance: my-instance
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-bigtable-instance\" as \"bigtable\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | instance: my-instance\n 3 | name: my-bigtable-instance\n 4 | project: my-project\n 5 | ",
err: "unable to parse source \"my-bigtable-instance\" as \"bigtable\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | instance: my-instance\n 3 | kind: bigtable\n 4 | project: my-project",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-bigtable-instance
type: bigtable
project: my-project
sources:
my-bigtable-instance:
kind: bigtable
project: my-project
`,
err: "error unmarshaling sources: unable to parse source \"my-bigtable-instance\" as \"bigtable\": Key: 'Config.Instance' Error:Field validation for 'Instance' failed on the 'required' tag",
err: "unable to parse source \"my-bigtable-instance\" as \"bigtable\": Key: 'Config.Instance' Error:Field validation for 'Instance' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -25,11 +25,11 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "cassandra"
const SourceKind string = "cassandra"
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -43,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Hosts []string `yaml:"hosts" validate:"required"`
Keyspace string `yaml:"keyspace"`
ProtoVersion int `yaml:"protoVersion"`
@@ -68,9 +68,9 @@ func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return s, nil
}
// SourceConfigType implements sources.SourceConfig.
func (c Config) SourceConfigType() string {
return SourceType
// SourceConfigKind implements sources.SourceConfig.
func (c Config) SourceConfigKind() string {
return SourceKind
}
var _ sources.SourceConfig = Config{}
@@ -89,9 +89,9 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
// SourceType implements sources.Source.
func (s *Source) SourceType() string {
return SourceType
// SourceKind implements sources.Source.
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) RunSQL(ctx context.Context, statement string, params parameters.ParamValues) (any, error) {
@@ -120,7 +120,7 @@ var _ sources.Source = &Source{}
func initCassandraSession(ctx context.Context, tracer trace.Tracer, c Config) (*gocql.Session, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, c.Name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, c.Name)
defer span.End()
// Validate authentication configuration

View File

@@ -15,12 +15,11 @@
package cassandra_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,17 +33,17 @@ func TestParseFromYamlCassandra(t *testing.T) {
{
desc: "basic example (without optional fields)",
in: `
kind: sources
name: my-cassandra-instance
type: cassandra
hosts:
- "my-host1"
- "my-host2"
sources:
my-cassandra-instance:
kind: cassandra
hosts:
- "my-host1"
- "my-host2"
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-cassandra-instance": cassandra.Config{
Name: "my-cassandra-instance",
Type: cassandra.SourceType,
Kind: cassandra.SourceKind,
Hosts: []string{"my-host1", "my-host2"},
Username: "",
Password: "",
@@ -60,25 +59,25 @@ func TestParseFromYamlCassandra(t *testing.T) {
{
desc: "with optional fields",
in: `
kind: sources
name: my-cassandra-instance
type: cassandra
hosts:
- "my-host1"
- "my-host2"
username: "user"
password: "pass"
keyspace: "example_keyspace"
protoVersion: 4
caPath: "path/to/ca.crt"
certPath: "path/to/cert"
keyPath: "path/to/key"
enableHostVerification: true
sources:
my-cassandra-instance:
kind: cassandra
hosts:
- "my-host1"
- "my-host2"
username: "user"
password: "pass"
keyspace: "example_keyspace"
protoVersion: 4
caPath: "path/to/ca.crt"
certPath: "path/to/cert"
keyPath: "path/to/key"
enableHostVerification: true
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-cassandra-instance": cassandra.Config{
Name: "my-cassandra-instance",
Type: cassandra.SourceType,
Kind: cassandra.SourceKind,
Hosts: []string{"my-host1", "my-host2"},
Username: "user",
Password: "pass",
@@ -94,12 +93,16 @@ func TestParseFromYamlCassandra(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -115,29 +118,33 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-cassandra-instance
type: cassandra
hosts:
- "my-host"
foo: bar
sources:
my-cassandra-instance:
kind: cassandra
hosts:
- "my-host"
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-cassandra-instance\" as \"cassandra\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | hosts:\n 3 | - my-host\n 4 | name: my-cassandra-instance\n 5 | ",
err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | hosts:\n 3 | - my-host\n 4 | kind: cassandra",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-cassandra-instance
type: cassandra
sources:
my-cassandra-instance:
kind: cassandra
`,
err: "error unmarshaling sources: unable to parse source \"my-cassandra-instance\" as \"cassandra\": Key: 'Config.Hosts' Error:Field validation for 'Hosts' failed on the 'required' tag",
err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": Key: 'Config.Hosts' Error:Field validation for 'Hosts' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -28,14 +28,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "clickhouse"
const SourceKind string = "clickhouse"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
Database string `yaml:"database" validate:"required"`
@@ -59,8 +59,8 @@ type Config struct {
Secure bool `yaml:"secure"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -88,8 +88,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -174,7 +174,7 @@ func validateConfig(protocol string) error {
func initClickHouseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, protocol string, secure bool) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
if protocol == "" {

View File

@@ -21,113 +21,137 @@ import (
"github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/testutils"
"go.opentelemetry.io/otel"
)
func TestParseFromYamlClickhouse(t *testing.T) {
tcs := []struct {
desc string
in string
want server.SourceConfigs
}{
{
desc: "all fields specified",
in: `
kind: sources
name: test-clickhouse
type: clickhouse
host: localhost
port: "8443"
user: default
password: "mypass"
database: mydb
protocol: https
secure: true
`,
want: map[string]sources.SourceConfig{
"test-clickhouse": Config{
Name: "test-clickhouse",
Type: "clickhouse",
Host: "localhost",
Port: "8443",
User: "default",
Password: "mypass",
Database: "mydb",
Protocol: "https",
Secure: true,
},
},
},
{
desc: "minimal configuration with defaults",
in: `
kind: sources
name: minimal-clickhouse
type: clickhouse
host: 127.0.0.1
port: "8123"
user: testuser
database: testdb
`,
want: map[string]sources.SourceConfig{
"minimal-clickhouse": Config{
Name: "minimal-clickhouse",
Type: "clickhouse",
Host: "127.0.0.1",
Port: "8123",
User: "testuser",
Password: "",
Database: "testdb",
Protocol: "",
Secure: false,
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
func TestConfigSourceConfigKind(t *testing.T) {
config := Config{}
if config.SourceConfigKind() != SourceKind {
t.Errorf("Expected %s, got %s", SourceKind, config.SourceConfigKind())
}
}
func TestFailParseFromYaml(t *testing.T) {
tcs := []struct {
desc string
in string
err string
func TestNewConfig(t *testing.T) {
tests := []struct {
name string
yaml string
expected Config
}{
{
desc: "extra field",
in: `
kind: sources
name: test-clickhouse
type: clickhouse
host: localhost
foo: bar
name: "all fields specified",
yaml: `
name: test-clickhouse
kind: clickhouse
host: localhost
port: "8443"
user: default
password: "mypass"
database: mydb
protocol: https
secure: true
`,
err: "error unmarshaling sources: unable to parse source \"test-clickhouse\" as \"clickhouse\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | host: localhost\n 3 | name: test-clickhouse\n 4 | type: clickhouse",
expected: Config{
Name: "test-clickhouse",
Kind: "clickhouse",
Host: "localhost",
Port: "8443",
User: "default",
Password: "mypass",
Database: "mydb",
Protocol: "https",
Secure: true,
},
},
{
name: "minimal configuration with defaults",
yaml: `
name: minimal-clickhouse
kind: clickhouse
host: 127.0.0.1
port: "8123"
user: testuser
database: testdb
`,
expected: Config{
Name: "minimal-clickhouse",
Kind: "clickhouse",
Host: "127.0.0.1",
Port: "8123",
User: "testuser",
Password: "",
Database: "testdb",
Protocol: "",
Secure: false,
},
},
{
name: "http protocol",
yaml: `
name: http-clickhouse
kind: clickhouse
host: clickhouse.example.com
port: "8123"
user: analytics
password: "securepass"
database: analytics_db
protocol: http
secure: false
`,
expected: Config{
Name: "http-clickhouse",
Kind: "clickhouse",
Host: "clickhouse.example.com",
Port: "8123",
User: "analytics",
Password: "securepass",
Database: "analytics_db",
Protocol: "http",
Secure: false,
},
},
{
name: "https with secure connection",
yaml: `
name: secure-clickhouse
kind: clickhouse
host: secure.clickhouse.io
port: "8443"
user: secureuser
password: "verysecure"
database: production
protocol: https
secure: true
`,
expected: Config{
Name: "secure-clickhouse",
Kind: "clickhouse",
Host: "secure.clickhouse.io",
Port: "8443",
User: "secureuser",
Password: "verysecure",
Database: "production",
Protocol: "https",
Secure: true,
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
decoder := yaml.NewDecoder(strings.NewReader(string(testutils.FormatYaml(tt.yaml))))
config, err := newConfig(context.Background(), tt.expected.Name, decoder)
if err != nil {
t.Fatalf("Failed to create config: %v", err)
}
errStr := err.Error()
if errStr != tc.err {
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
clickhouseConfig, ok := config.(Config)
if !ok {
t.Fatalf("Expected Config type, got %T", config)
}
if diff := cmp.Diff(tt.expected, clickhouseConfig); diff != "" {
t.Errorf("Config mismatch (-want +got):\n%s", diff)
}
})
}
@@ -143,11 +167,19 @@ func TestNewConfigInvalidYAML(t *testing.T) {
name: "invalid yaml syntax",
yaml: `
name: test-clickhouse
type: clickhouse
kind: clickhouse
host: [invalid
`,
expectError: true,
},
{
name: "missing required fields",
yaml: `
name: test-clickhouse
kind: clickhouse
`,
expectError: false,
},
}
for _, tt := range tests {
@@ -164,10 +196,10 @@ func TestNewConfigInvalidYAML(t *testing.T) {
}
}
func TestSource_SourceType(t *testing.T) {
func TestSource_SourceKind(t *testing.T) {
source := &Source{}
if source.SourceType() != SourceType {
t.Errorf("Expected %s, got %s", SourceType, source.SourceType())
if source.SourceKind() != SourceKind {
t.Errorf("Expected %s, got %s", SourceKind, source.SourceKind())
}
}

View File

@@ -29,15 +29,15 @@ import (
"golang.org/x/oauth2/google"
)
const SourceType string = "cloud-gemini-data-analytics"
const SourceKind string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -51,13 +51,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ProjectID string `yaml:"projectId" validate:"required"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize initializes a Gemini Data Analytics Source instance.
@@ -102,8 +102,8 @@ type Source struct {
userAgent string
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -20,6 +20,7 @@ import (
"path/filepath"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -38,15 +39,15 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-gda-instance
type: cloud-gemini-data-analytics
projectId: test-project-id
`,
sources:
my-gda-instance:
kind: cloud-gemini-data-analytics
projectId: test-project-id
`,
want: map[string]sources.SourceConfig{
"my-gda-instance": cloudgda.Config{
Name: "my-gda-instance",
Type: cloudgda.SourceType,
Kind: cloudgda.SourceKind,
ProjectID: "test-project-id",
UseClientOAuth: false,
},
@@ -55,16 +56,16 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
{
desc: "use client auth example",
in: `
kind: sources
name: my-gda-instance
type: cloud-gemini-data-analytics
projectId: another-project
useClientOAuth: true
sources:
my-gda-instance:
kind: cloud-gemini-data-analytics
projectId: another-project
useClientOAuth: true
`,
want: map[string]sources.SourceConfig{
"my-gda-instance": cloudgda.Config{
Name: "my-gda-instance",
Type: cloudgda.SourceType,
Kind: cloudgda.SourceKind,
ProjectID: "another-project",
UseClientOAuth: true,
},
@@ -75,12 +76,16 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -96,18 +101,22 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "missing projectId",
in: `
kind: sources
name: my-gda-instance
type: cloud-gemini-data-analytics
sources:
my-gda-instance:
kind: cloud-gemini-data-analytics
`,
err: "error unmarshaling sources: unable to parse source \"my-gda-instance\" as \"cloud-gemini-data-analytics\": Key: 'Config.ProjectID' Error:Field validation for 'ProjectID' failed on the 'required' tag",
err: "unable to parse source \"my-gda-instance\" as \"cloud-gemini-data-analytics\": Key: 'Config.ProjectID' Error:Field validation for 'ProjectID' failed on the 'required' tag",
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}
@@ -144,12 +153,12 @@ func TestInitialize(t *testing.T) {
}{
{
desc: "initialize with ADC",
cfg: cloudgda.Config{Name: "test-gda", Type: cloudgda.SourceType, ProjectID: "test-proj"},
cfg: cloudgda.Config{Name: "test-gda", Kind: cloudgda.SourceKind, ProjectID: "test-proj"},
wantClientOAuth: false,
},
{
desc: "initialize with client OAuth",
cfg: cloudgda.Config{Name: "test-gda-oauth", Type: cloudgda.SourceType, ProjectID: "test-proj", UseClientOAuth: true},
cfg: cloudgda.Config{Name: "test-gda-oauth", Kind: cloudgda.SourceKind, ProjectID: "test-proj", UseClientOAuth: true},
wantClientOAuth: true,
},
}

View File

@@ -34,7 +34,7 @@ import (
"google.golang.org/api/option"
)
const SourceType string = "cloud-healthcare"
const SourceKind string = "cloud-healthcare"
// validate interface
var _ sources.SourceConfig = Config{}
@@ -42,8 +42,8 @@ var _ sources.SourceConfig = Config{}
type HealthcareServiceCreator func(tokenString string) (*healthcare.Service, error)
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -58,7 +58,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Healthcare configs
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Dataset string `yaml:"dataset" validate:"required"`
@@ -67,8 +67,8 @@ type Config struct {
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (c Config) SourceConfigType() string {
return SourceType
func (c Config) SourceConfigKind() string {
return SourceKind
}
func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -144,7 +144,7 @@ func newHealthcareServiceCreator(ctx context.Context, tracer trace.Tracer, name
}
func initHealthcareConnectionWithOAuthToken(ctx context.Context, tracer trace.Tracer, name string, userAgent string, tokenString string) (*healthcare.Service, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Construct token source
token := &oauth2.Token{
@@ -162,7 +162,7 @@ func initHealthcareConnectionWithOAuthToken(ctx context.Context, tracer trace.Tr
}
func initHealthcareConnection(ctx context.Context, tracer trace.Tracer, name string) (*healthcare.Service, oauth2.TokenSource, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
cred, err := google.FindDefaultCredentials(ctx, healthcare.CloudHealthcareScope)
@@ -194,8 +194,8 @@ type Source struct {
allowedDICOMStores map[string]struct{}
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -517,14 +517,14 @@ func (s *Source) RetrieveRenderedDICOMInstance(storeID, study, series, sop strin
return base64String, nil
}
func (s *Source) SearchDICOM(toolType, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
func (s *Source) SearchDICOM(toolKind, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
var resp *http.Response
switch toolType {
switch toolKind {
case "cloud-healthcare-search-dicom-instances":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
case "cloud-healthcare-search-dicom-series":
@@ -532,7 +532,7 @@ func (s *Source) SearchDICOM(toolType, storeID, dicomWebPath, tokenStr string, o
case "cloud-healthcare-search-dicom-studies":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, dicomWebPath).Do(opts...)
default:
return nil, fmt.Errorf("incompatible tool type: %s", toolType)
return nil, fmt.Errorf("incompatible tool kind: %s", toolKind)
}
if err != nil {
return nil, fmt.Errorf("failed to search dicom series: %w", err)

View File

@@ -15,12 +15,11 @@
package cloudhealthcare_test
import (
"context"
"testing"
"github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,17 +33,17 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-instance
type: cloud-healthcare
project: my-project
region: us-central1
dataset: my-dataset
sources:
my-instance:
kind: cloud-healthcare
project: my-project
region: us-central1
dataset: my-dataset
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": cloudhealthcare.Config{
Name: "my-instance",
Type: cloudhealthcare.SourceType,
Kind: cloudhealthcare.SourceKind,
Project: "my-project",
Region: "us-central1",
Dataset: "my-dataset",
@@ -55,18 +54,18 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
{
desc: "use client auth example",
in: `
kind: sources
name: my-instance
type: cloud-healthcare
project: my-project
region: us
dataset: my-dataset
useClientOAuth: true
sources:
my-instance:
kind: cloud-healthcare
project: my-project
region: us
dataset: my-dataset
useClientOAuth: true
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": cloudhealthcare.Config{
Name: "my-instance",
Type: cloudhealthcare.SourceType,
Kind: cloudhealthcare.SourceKind,
Project: "my-project",
Region: "us",
Dataset: "my-dataset",
@@ -77,22 +76,22 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
{
desc: "with allowed stores example",
in: `
kind: sources
name: my-instance
type: cloud-healthcare
project: my-project
region: us
dataset: my-dataset
allowedFhirStores:
- my-fhir-store
allowedDicomStores:
- my-dicom-store1
- my-dicom-store2
sources:
my-instance:
kind: cloud-healthcare
project: my-project
region: us
dataset: my-dataset
allowedFhirStores:
- my-fhir-store
allowedDicomStores:
- my-dicom-store1
- my-dicom-store2
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": cloudhealthcare.Config{
Name: "my-instance",
Type: cloudhealthcare.SourceType,
Kind: cloudhealthcare.SourceKind,
Project: "my-project",
Region: "us",
Dataset: "my-dataset",
@@ -104,12 +103,16 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -124,31 +127,35 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-instance
type: cloud-healthcare
project: my-project
region: us-central1
dataset: my-dataset
foo: bar
sources:
my-instance:
kind: cloud-healthcare
project: my-project
region: us-central1
dataset: my-dataset
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-healthcare\": [2:1] unknown field \"foo\"\n 1 | dataset: my-dataset\n> 2 | foo: bar\n ^\n 3 | name: my-instance\n 4 | project: my-project\n 5 | region: us-central1\n 6 | ",
err: "unable to parse source \"my-instance\" as \"cloud-healthcare\": [2:1] unknown field \"foo\"\n 1 | dataset: my-dataset\n> 2 | foo: bar\n ^\n 3 | kind: cloud-healthcare\n 4 | project: my-project\n 5 | region: us-central1",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-instance
type: cloud-healthcare
project: my-project
region: us-central1
sources:
my-instance:
kind: cloud-healthcare
project: my-project
region: us-central1
`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-healthcare\": Key: 'Config.Dataset' Error:Field validation for 'Dataset' failed on the 'required' tag",
err: `unable to parse source "my-instance" as "cloud-healthcare": Key: 'Config.Dataset' Error:Field validation for 'Dataset' failed on the 'required' tag`,
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
monitoring "google.golang.org/api/monitoring/v3"
)
const SourceType string = "cloud-monitoring"
const SourceKind string = "cloud-monitoring"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -50,12 +50,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize initializes a Cloud Monitoring Source instance.
@@ -99,8 +99,8 @@ type Source struct {
userAgent string
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,9 +15,9 @@
package cloudmonitoring_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -35,14 +35,14 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-cloud-monitoring-instance
type: cloud-monitoring
sources:
my-cloud-monitoring-instance:
kind: cloud-monitoring
`,
want: map[string]sources.SourceConfig{
"my-cloud-monitoring-instance": cloudmonitoring.Config{
Name: "my-cloud-monitoring-instance",
Type: cloudmonitoring.SourceType,
Kind: cloudmonitoring.SourceKind,
UseClientOAuth: false,
},
},
@@ -50,15 +50,15 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
{
desc: "use client auth example",
in: `
kind: sources
name: my-cloud-monitoring-instance
type: cloud-monitoring
useClientOAuth: true
sources:
my-cloud-monitoring-instance:
kind: cloud-monitoring
useClientOAuth: true
`,
want: map[string]sources.SourceConfig{
"my-cloud-monitoring-instance": cloudmonitoring.Config{
Name: "my-cloud-monitoring-instance",
Type: cloudmonitoring.SourceType,
Kind: cloudmonitoring.SourceKind,
UseClientOAuth: true,
},
},
@@ -68,12 +68,16 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -89,28 +93,36 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-cloud-monitoring-instance
type: cloud-monitoring
project: test-project
sources:
my-cloud-monitoring-instance:
kind: cloud-monitoring
project: test-project
`,
err: "error unmarshaling sources: unable to parse source \"my-cloud-monitoring-instance\" as \"cloud-monitoring\": [2:1] unknown field \"project\"\n 1 | name: my-cloud-monitoring-instance\n> 2 | project: test-project\n ^\n 3 | type: cloud-monitoring",
err: `unable to parse source "my-cloud-monitoring-instance" as "cloud-monitoring": [2:1] unknown field "project"
1 | kind: cloud-monitoring
> 2 | project: test-project
^
`,
},
{
desc: "missing required field",
in: `
kind: sources
name: my-cloud-monitoring-instance
useClientOAuth: true
sources:
my-cloud-monitoring-instance:
useClientOAuth: true
`,
err: "error unmarshaling sources: missing 'type' field or it is not a string",
err: "missing 'kind' field for source \"my-cloud-monitoring-instance\"",
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -34,7 +34,7 @@ import (
sqladmin "google.golang.org/api/sqladmin/v1"
)
const SourceType string = "cloud-sql-admin"
const SourceKind string = "cloud-sql-admin"
var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
@@ -42,8 +42,8 @@ var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/da
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -57,13 +57,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
DefaultProject string `yaml:"defaultProject"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize initializes a CloudSQL Admin Source instance.
@@ -110,8 +110,8 @@ type Source struct {
Service *sqladmin.Service
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -352,28 +352,6 @@ func (s *Source) GetWaitForOperations(ctx context.Context, service *sqladmin.Ser
return nil, nil
}
func (s *Source) InsertBackupRun(ctx context.Context, project, instance, location, backupDescription, accessToken string) (any, error) {
backupRun := &sqladmin.BackupRun{}
if location != "" {
backupRun.Location = location
}
if backupDescription != "" {
backupRun.Description = backupDescription
}
service, err := s.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
resp, err := service.BackupRuns.Insert(project, instance, backupRun).Do()
if err != nil {
return nil, fmt.Errorf("error creating backup: %w", err)
}
return resp, nil
}
func generateCloudSQLConnectionMessage(ctx context.Context, source *Source, logger log.Logger, opResponse map[string]any, connectionMessageTemplate string) (string, bool) {
operationType, ok := opResponse["operationType"].(string)
if !ok || operationType != "CREATE_DATABASE" {

View File

@@ -15,9 +15,9 @@
package cloudsqladmin_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -35,14 +35,14 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-cloud-sql-admin-instance
type: cloud-sql-admin
sources:
my-cloud-sql-admin-instance:
kind: cloud-sql-admin
`,
want: map[string]sources.SourceConfig{
"my-cloud-sql-admin-instance": cloudsqladmin.Config{
Name: "my-cloud-sql-admin-instance",
Type: cloudsqladmin.SourceType,
Kind: cloudsqladmin.SourceKind,
UseClientOAuth: false,
},
},
@@ -50,15 +50,15 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
{
desc: "use client auth example",
in: `
kind: sources
name: my-cloud-sql-admin-instance
type: cloud-sql-admin
useClientOAuth: true
sources:
my-cloud-sql-admin-instance:
kind: cloud-sql-admin
useClientOAuth: true
`,
want: map[string]sources.SourceConfig{
"my-cloud-sql-admin-instance": cloudsqladmin.Config{
Name: "my-cloud-sql-admin-instance",
Type: cloudsqladmin.SourceType,
Kind: cloudsqladmin.SourceKind,
UseClientOAuth: true,
},
},
@@ -68,12 +68,16 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -89,28 +93,36 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-cloud-sql-admin-instance
type: cloud-sql-admin
project: test-project
sources:
my-cloud-sql-admin-instance:
kind: cloud-sql-admin
project: test-project
`,
err: "error unmarshaling sources: unable to parse source \"my-cloud-sql-admin-instance\" as \"cloud-sql-admin\": [2:1] unknown field \"project\"\n 1 | name: my-cloud-sql-admin-instance\n> 2 | project: test-project\n ^\n 3 | type: cloud-sql-admin",
err: `unable to parse source "my-cloud-sql-admin-instance" as "cloud-sql-admin": [2:1] unknown field "project"
1 | kind: cloud-sql-admin
> 2 | project: test-project
^
`,
},
{
desc: "missing required field",
in: `
kind: sources
name: my-cloud-sql-admin-instance
useClientOAuth: true
sources:
my-cloud-sql-admin-instance:
useClientOAuth: true
`,
err: "error unmarshaling sources: missing 'type' field or it is not a string",
err: "missing 'kind' field for source \"my-cloud-sql-admin-instance\"",
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "cloud-sql-mssql"
const SourceKind string = "cloud-sql-mssql"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Cloud SQL MSSQL configs
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
@@ -62,9 +62,9 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigType() string {
// Returns Cloud SQL MSSQL source type
return SourceType
func (r Config) SourceConfigKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -94,9 +94,9 @@ type Source struct {
Db *sql.DB
}
func (s *Source) SourceType() string {
// Returns Cloud SQL MSSQL source type
return SourceType
func (s *Source) SourceKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -152,7 +152,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -15,12 +15,11 @@
package cloudsqlmssql_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,20 +33,20 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-instance
type: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
sources:
my-instance:
kind: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": cloudsqlmssql.Config{
Name: "my-instance",
Type: cloudsqlmssql.SourceType,
Kind: cloudsqlmssql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -61,21 +60,21 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
{
desc: "psc ipType",
in: `
kind: sources
name: my-instance
type: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
ipType: psc
sources:
my-instance:
kind: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
ipType: psc
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": cloudsqlmssql.Config{
Name: "my-instance",
Type: cloudsqlmssql.SourceType,
Kind: cloudsqlmssql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -89,21 +88,21 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
{
desc: "with deprecated ipAddress",
in: `
kind: sources
name: my-instance
type: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
ipAddress: random
database: my_db
user: my_user
password: my_pass
sources:
my-instance:
kind: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
ipAddress: random
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": cloudsqlmssql.Config{
Name: "my-instance",
Type: cloudsqlmssql.SourceType,
Kind: cloudsqlmssql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -118,12 +117,16 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -139,53 +142,57 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "invalid ipType",
in: `
kind: sources
name: my-instance
type: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
sources:
my-instance:
kind: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-sql-mssql\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
err: "unable to parse source \"my-instance\" as \"cloud-sql-mssql\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
},
{
desc: "extra field",
in: `
kind: sources
name: my-instance
type: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
sources:
my-instance:
kind: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-sql-mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | name: my-instance\n 5 | password: my_pass\n 6 | ",
err: "unable to parse source \"my-instance\" as \"cloud-sql-mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | kind: cloud-sql-mssql\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-instance
type: cloud-sql-mssql
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
sources:
my-instance:
kind: cloud-sql-mssql
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-sql-mssql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "unable to parse source \"my-instance\" as \"cloud-sql-mssql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -30,14 +30,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "cloud-sql-mysql"
const SourceKind string = "cloud-sql-mysql"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
@@ -61,8 +61,8 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -90,8 +90,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -184,7 +184,7 @@ func getConnectionConfig(ctx context.Context, user, pass string) (string, string
func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Configure the driver to connect to the database

View File

@@ -15,12 +15,11 @@
package cloudsqlmysql_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,20 +33,20 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
sources:
my-mysql-instance:
kind: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-mysql-instance": cloudsqlmysql.Config{
Name: "my-mysql-instance",
Type: cloudsqlmysql.SourceType,
Kind: cloudsqlmysql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -61,21 +60,21 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "public ipType",
in: `
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: Public
database: my_db
user: my_user
password: my_pass
sources:
my-mysql-instance:
kind: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: Public
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-mysql-instance": cloudsqlmysql.Config{
Name: "my-mysql-instance",
Type: cloudsqlmysql.SourceType,
Kind: cloudsqlmysql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -89,21 +88,21 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "private ipType",
in: `
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: private
database: my_db
user: my_user
password: my_pass
sources:
my-mysql-instance:
kind: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: private
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-mysql-instance": cloudsqlmysql.Config{
Name: "my-mysql-instance",
Type: cloudsqlmysql.SourceType,
Kind: cloudsqlmysql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -117,21 +116,21 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "psc ipType",
in: `
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: psc
database: my_db
user: my_user
password: my_pass
sources:
my-mysql-instance:
kind: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: psc
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-mysql-instance": cloudsqlmysql.Config{
Name: "my-mysql-instance",
Type: cloudsqlmysql.SourceType,
Kind: cloudsqlmysql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -145,12 +144,16 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -166,53 +169,57 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "invalid ipType",
in: `
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
sources:
my-mysql-instance:
kind: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
err: "unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
},
{
desc: "extra field",
in: `
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
sources:
my-mysql-instance:
kind: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | name: my-mysql-instance\n 5 | password: my_pass\n 6 | ",
err: "unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | kind: cloud-sql-mysql\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
sources:
my-mysql-instance:
kind: cloud-sql-mysql
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -28,14 +28,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "cloud-sql-postgres"
const SourceKind string = "cloud-sql-postgres"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
@@ -59,8 +59,8 @@ type Config struct {
Password string `yaml:"password"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -88,8 +88,8 @@ type Source struct {
Pool *pgxpool.Pool
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -162,7 +162,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string
func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Configure the driver to connect to the database

View File

@@ -15,12 +15,11 @@
package cloudsqlpg_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,20 +33,20 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Type: cloudsqlpg.SourceType,
Kind: cloudsqlpg.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -61,21 +60,21 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
{
desc: "public ipType",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: Public
database: my_db
user: my_user
password: my_pass
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: Public
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Type: cloudsqlpg.SourceType,
Kind: cloudsqlpg.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -89,21 +88,21 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
{
desc: "private ipType",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: private
database: my_db
user: my_user
password: my_pass
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: private
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Type: cloudsqlpg.SourceType,
Kind: cloudsqlpg.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -117,21 +116,21 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
{
desc: "psc ipType",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: psc
database: my_db
user: my_user
password: my_pass
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: psc
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Type: cloudsqlpg.SourceType,
Kind: cloudsqlpg.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -145,12 +144,16 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -166,53 +169,57 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "invalid ipType",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
err: "unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
},
{
desc: "extra field",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | name: my-pg-instance\n 5 | password: my_pass\n 6 | ",
err: "unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | kind: cloud-sql-postgres\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
sources:
my-pg-instance:
kind: cloud-sql-postgres
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "couchbase"
const SourceKind string = "couchbase"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ConnectionString string `yaml:"connectionString" validate:"required"`
Bucket string `yaml:"bucket" validate:"required"`
Scope string `yaml:"scope" validate:"required"`
@@ -66,8 +66,8 @@ type Config struct {
QueryScanConsistency uint `yaml:"queryScanConsistency"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -96,8 +96,8 @@ type Source struct {
Scope *gocb.Scope
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,12 +15,11 @@
package couchbase_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/couchbase"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,19 +33,19 @@ func TestParseFromYamlCouchbase(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-couchbase-instance
type: couchbase
connectionString: localhost
username: Administrator
password: password
bucket: travel-sample
scope: inventory
sources:
my-couchbase-instance:
kind: couchbase
connectionString: localhost
username: Administrator
password: password
bucket: travel-sample
scope: inventory
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-couchbase-instance": couchbase.Config{
Name: "my-couchbase-instance",
Type: couchbase.SourceType,
Kind: couchbase.SourceKind,
ConnectionString: "localhost",
Username: "Administrator",
Password: "password",
@@ -58,24 +57,24 @@ func TestParseFromYamlCouchbase(t *testing.T) {
{
desc: "with TLS configuration",
in: `
kind: sources
name: my-couchbase-instance
type: couchbase
connectionString: couchbases://localhost
bucket: travel-sample
scope: inventory
clientCert: /path/to/cert.pem
clientKey: /path/to/key.pem
clientCertPassword: password
clientKeyPassword: password
caCert: /path/to/ca.pem
noSslVerify: false
queryScanConsistency: 2
sources:
my-couchbase-instance:
kind: couchbase
connectionString: couchbases://localhost
bucket: travel-sample
scope: inventory
clientCert: /path/to/cert.pem
clientKey: /path/to/key.pem
clientCertPassword: password
clientKeyPassword: password
caCert: /path/to/ca.pem
noSslVerify: false
queryScanConsistency: 2
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-couchbase-instance": couchbase.Config{
Name: "my-couchbase-instance",
Type: couchbase.SourceType,
Kind: couchbase.SourceKind,
ConnectionString: "couchbases://localhost",
Bucket: "travel-sample",
Scope: "inventory",
@@ -92,12 +91,16 @@ func TestParseFromYamlCouchbase(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -112,35 +115,39 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-couchbase-instance
type: couchbase
connectionString: localhost
username: Administrator
password: password
bucket: travel-sample
scope: inventory
foo: bar
sources:
my-couchbase-instance:
kind: couchbase
connectionString: localhost
username: Administrator
password: password
bucket: travel-sample
scope: inventory
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-couchbase-instance\" as \"couchbase\": [3:1] unknown field \"foo\"\n 1 | bucket: travel-sample\n 2 | connectionString: localhost\n> 3 | foo: bar\n ^\n 4 | name: my-couchbase-instance\n 5 | password: password\n 6 | scope: inventory\n 7 | ",
err: "unable to parse source \"my-couchbase-instance\" as \"couchbase\": [3:1] unknown field \"foo\"\n 1 | bucket: travel-sample\n 2 | connectionString: localhost\n> 3 | foo: bar\n ^\n 4 | kind: couchbase\n 5 | password: password\n 6 | scope: inventory\n 7 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-couchbase-instance
type: couchbase
username: Administrator
password: password
bucket: travel-sample
scope: inventory
sources:
my-couchbase-instance:
kind: couchbase
username: Administrator
password: password
bucket: travel-sample
scope: inventory
`,
err: "error unmarshaling sources: unable to parse source \"my-couchbase-instance\" as \"couchbase\": Key: 'Config.ConnectionString' Error:Field validation for 'ConnectionString' failed on the 'required' tag",
err: "unable to parse source \"my-couchbase-instance\" as \"couchbase\": Key: 'Config.ConnectionString' Error:Field validation for 'ConnectionString' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
"google.golang.org/api/option"
)
const SourceType string = "dataplex"
const SourceKind string = "dataplex"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -51,13 +51,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Dataplex configs
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
}
func (r Config) SourceConfigType() string {
// Returns Dataplex source type
return SourceType
func (r Config) SourceConfigKind() string {
// Returns Dataplex source kind
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -81,9 +81,9 @@ type Source struct {
Client *dataplexapi.CatalogClient
}
func (s *Source) SourceType() string {
// Returns Dataplex source type
return SourceType
func (s *Source) SourceKind() string {
// Returns Dataplex source kind
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -104,7 +104,7 @@ func initDataplexConnection(
name string,
project string,
) (*dataplexapi.CatalogClient, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
cred, err := google.FindDefaultCredentials(ctx)

View File

@@ -15,12 +15,11 @@
package dataplex_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/dataplex"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,15 +33,15 @@ func TestParseFromYamlDataplex(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-instance
type: dataplex
project: my-project
sources:
my-instance:
kind: dataplex
project: my-project
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": dataplex.Config{
Name: "my-instance",
Type: dataplex.SourceType,
Kind: dataplex.SourceKind,
Project: "my-project",
},
},
@@ -50,12 +49,16 @@ func TestParseFromYamlDataplex(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -71,27 +74,31 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-instance
type: dataplex
project: my-project
foo: bar
sources:
my-instance:
kind: dataplex
project: my-project
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"dataplex\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | name: my-instance\n 3 | project: my-project\n 4 | type: dataplex",
err: "unable to parse source \"my-instance\" as \"dataplex\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: dataplex\n 3 | project: my-project",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-instance
type: dataplex
sources:
my-instance:
kind: dataplex
`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"dataplex\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "unable to parse source \"my-instance\" as \"dataplex\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -30,14 +30,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "dgraph"
const SourceKind string = "dgraph"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -67,7 +67,7 @@ type DgraphClient struct {
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
DgraphUrl string `yaml:"dgraphUrl" validate:"required"`
User string `yaml:"user"`
Password string `yaml:"password"`
@@ -75,8 +75,8 @@ type Config struct {
ApiKey string `yaml:"apiKey"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -103,8 +103,8 @@ type Source struct {
Client *DgraphClient `yaml:"client"`
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -139,7 +139,7 @@ func (s *Source) RunSQL(statement string, params parameters.ParamValues, isQuery
func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (*DgraphClient, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, r.Name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, r.Name)
defer span.End()
if r.DgraphUrl == "" {

View File

@@ -15,12 +15,11 @@
package dgraph_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/dgraph"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,19 +33,19 @@ func TestParseFromYamlDgraph(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-dgraph-instance
type: dgraph
dgraphUrl: https://localhost:8080
apiKey: abc123
password: pass@123
namespace: 0
user: user123
sources:
my-dgraph-instance:
kind: dgraph
dgraphUrl: https://localhost:8080
apiKey: abc123
password: pass@123
namespace: 0
user: user123
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-dgraph-instance": dgraph.Config{
Name: "my-dgraph-instance",
Type: dgraph.SourceType,
Kind: dgraph.SourceKind,
DgraphUrl: "https://localhost:8080",
ApiKey: "abc123",
Password: "pass@123",
@@ -58,15 +57,15 @@ func TestParseFromYamlDgraph(t *testing.T) {
{
desc: "basic example minimal field",
in: `
kind: sources
name: my-dgraph-instance
type: dgraph
dgraphUrl: https://localhost:8080
sources:
my-dgraph-instance:
kind: dgraph
dgraphUrl: https://localhost:8080
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-dgraph-instance": dgraph.Config{
Name: "my-dgraph-instance",
Type: dgraph.SourceType,
Kind: dgraph.SourceKind,
DgraphUrl: "https://localhost:8080",
},
},
@@ -75,12 +74,16 @@ func TestParseFromYamlDgraph(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got); diff != "" {
if diff := cmp.Diff(tc.want, got.Sources); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
@@ -97,27 +100,31 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-dgraph-instance
type: dgraph
dgraphUrl: https://localhost:8080
foo: bar
sources:
my-dgraph-instance:
kind: dgraph
dgraphUrl: https://localhost:8080
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-dgraph-instance\" as \"dgraph\": [2:1] unknown field \"foo\"\n 1 | dgraphUrl: https://localhost:8080\n> 2 | foo: bar\n ^\n 3 | name: my-dgraph-instance\n 4 | type: dgraph",
err: "unable to parse source \"my-dgraph-instance\" as \"dgraph\": [2:1] unknown field \"foo\"\n 1 | dgraphUrl: https://localhost:8080\n> 2 | foo: bar\n ^\n 3 | kind: dgraph",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-dgraph-instance
type: dgraph
sources:
my-dgraph-instance:
kind: dgraph
`,
err: "error unmarshaling sources: unable to parse source \"my-dgraph-instance\" as \"dgraph\": Key: 'Config.DgraphUrl' Error:Field validation for 'DgraphUrl' failed on the 'required' tag",
err: "unable to parse source \"my-dgraph-instance\" as \"dgraph\": Key: 'Config.DgraphUrl' Error:Field validation for 'DgraphUrl' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -30,14 +30,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "elasticsearch"
const SourceKind string = "elasticsearch"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -51,15 +51,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Addresses []string `yaml:"addresses" validate:"required"`
Username string `yaml:"username"`
Password string `yaml:"password"`
APIKey string `yaml:"apikey"`
}
func (c Config) SourceConfigType() string {
return SourceType
func (c Config) SourceConfigKind() string {
return SourceKind
}
type EsClient interface {
@@ -139,9 +139,9 @@ func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return s, nil
}
// SourceType returns the resourceType string for this source.
func (s *Source) SourceType() string {
return SourceType
// SourceKind returns the kind string for this source.
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,15 +15,13 @@
package elasticsearch_test
import (
"context"
"reflect"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/elasticsearch"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
func TestParseFromYamlElasticsearch(t *testing.T) {
@@ -35,17 +33,17 @@ func TestParseFromYamlElasticsearch(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-es-instance
type: elasticsearch
addresses:
- http://localhost:9200
apikey: somekey
`,
want: map[string]sources.SourceConfig{
sources:
my-es-instance:
kind: elasticsearch
addresses:
- http://localhost:9200
apikey: somekey
`,
want: server.SourceConfigs{
"my-es-instance": elasticsearch.Config{
Name: "my-es-instance",
Type: elasticsearch.SourceType,
Kind: elasticsearch.SourceKind,
Addresses: []string{"http://localhost:9200"},
APIKey: "somekey",
},
@@ -54,50 +52,20 @@ func TestParseFromYamlElasticsearch(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
err := yaml.Unmarshal([]byte(tc.in), &got)
if err != nil {
t.Fatalf("failed to parse yaml: %v", err)
}
if diff := cmp.Diff(tc.want, got); diff != "" {
if diff := cmp.Diff(tc.want, got.Sources); diff != "" {
t.Errorf("unexpected config diff (-want +got):\n%s", diff)
}
})
}
}
func TestFailParseFromYaml(t *testing.T) {
tcs := []struct {
desc string
in string
err string
}{
{
desc: "extra field",
in: `
kind: sources
name: my-es-instance
type: elasticsearch
addresses:
- http://localhost:9200
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-es-instance\" as \"elasticsearch\": [3:1] unknown field \"foo\"\n 1 | addresses:\n 2 | - http://localhost:9200\n> 3 | foo: bar\n ^\n 4 | name: my-es-instance\n 5 | type: elasticsearch",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}
errStr := err.Error()
if errStr != tc.err {
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
}
})
}
}
func TestTool_esqlToMap(t1 *testing.T) {
tests := []struct {
name string

View File

@@ -27,13 +27,13 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
)
const SourceType string = "firebird"
const SourceKind string = "firebird"
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -47,7 +47,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -55,8 +55,8 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -84,8 +84,8 @@ type Source struct {
Db *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -144,7 +144,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
}
func initFirebirdConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*sql.DB, error) {
_, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// urlExample := "user:password@host:port/path/to/database.fdb"

View File

@@ -15,12 +15,11 @@
package firebird_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/firebird"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,19 +33,19 @@ func TestParseFromYamlFirebird(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-fdb-instance
type: firebird
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
sources:
my-fdb-instance:
kind: firebird
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-fdb-instance": firebird.Config{
Name: "my-fdb-instance",
Type: firebird.SourceType,
Kind: firebird.SourceKind,
Host: "my-host",
Port: "my-port",
Database: "my_db",
@@ -58,12 +57,16 @@ func TestParseFromYamlFirebird(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -79,35 +82,39 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-fdb-instance
type: firebird
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
sources:
my-fdb-instance:
kind: firebird
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-fdb-instance\" as \"firebird\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: my-host\n 4 | name: my-fdb-instance\n 5 | password: my_pass\n 6 | ",
err: "unable to parse source \"my-fdb-instance\" as \"firebird\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: my-host\n 4 | kind: firebird\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-fdb-instance
type: firebird
host: my-host
port: my-port
database: my_db
user: my_user
sources:
my-fdb-instance:
kind: firebird
host: my-host
port: my-port
database: my_db
user: my_user
`,
err: "error unmarshaling sources: unable to parse source \"my-fdb-instance\" as \"firebird\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
err: "unable to parse source \"my-fdb-instance\" as \"firebird\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -31,14 +31,14 @@ import (
"google.golang.org/genproto/googleapis/type/latlng"
)
const SourceType string = "firestore"
const SourceKind string = "firestore"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -53,14 +53,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Firestore configs
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Database string `yaml:"database"` // Optional, defaults to "(default)"
}
func (r Config) SourceConfigType() string {
// Returns Firestore source type
return SourceType
func (r Config) SourceConfigKind() string {
// Returns Firestore source kind
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -92,9 +92,9 @@ type Source struct {
RulesClient *firebaserules.Service
}
func (s *Source) SourceType() string {
// Returns Firestore source type
return SourceType
func (s *Source) SourceKind() string {
// Returns Firestore source kind
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -594,7 +594,7 @@ func initFirestoreConnection(
project string,
database string,
) (*firestore.Client, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -15,13 +15,12 @@
package firestore_test
import (
"context"
"testing"
"time"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/firestore"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -35,15 +34,15 @@ func TestParseFromYamlFirestore(t *testing.T) {
{
desc: "basic example with default database",
in: `
kind: sources
name: my-firestore
type: firestore
project: my-project
sources:
my-firestore:
kind: firestore
project: my-project
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-firestore": firestore.Config{
Name: "my-firestore",
Type: firestore.SourceType,
Kind: firestore.SourceKind,
Project: "my-project",
Database: "",
},
@@ -52,16 +51,16 @@ func TestParseFromYamlFirestore(t *testing.T) {
{
desc: "with custom database",
in: `
kind: sources
name: my-firestore
type: firestore
project: my-project
database: my-database
sources:
my-firestore:
kind: firestore
project: my-project
database: my-database
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-firestore": firestore.Config{
Name: "my-firestore",
Type: firestore.SourceType,
Kind: firestore.SourceKind,
Project: "my-project",
Database: "my-database",
},
@@ -70,18 +69,22 @@ func TestParseFromYamlFirestore(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
}
func TestFailParseFromYaml(t *testing.T) {
func TestFailParseFromYamlFirestore(t *testing.T) {
tcs := []struct {
desc string
in string
@@ -90,27 +93,32 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-firestore
type: firestore
project: my-project
foo: bar
sources:
my-firestore:
kind: firestore
project: my-project
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-firestore\" as \"firestore\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | name: my-firestore\n 3 | project: my-project\n 4 | type: firestore",
err: "unable to parse source \"my-firestore\" as \"firestore\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: firestore\n 3 | project: my-project",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-firestore
type: firestore
sources:
my-firestore:
kind: firestore
database: my-database
`,
err: "error unmarshaling sources: unable to parse source \"my-firestore\" as \"firestore\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "unable to parse source \"my-firestore\" as \"firestore\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "http"
const SourceKind string = "http"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
BaseURL string `yaml:"baseUrl"`
Timeout string `yaml:"timeout"`
DefaultHeaders map[string]string `yaml:"headers"`
@@ -58,8 +58,8 @@ type Config struct {
DisableSslVerification bool `yaml:"disableSslVerification"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize initializes an HTTP Source instance.
@@ -122,8 +122,8 @@ type Source struct {
client *http.Client
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,9 +15,9 @@
package http_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -34,15 +34,15 @@ func TestParseFromYamlHttp(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-http-instance
type: http
baseUrl: http://test_server/
sources:
my-http-instance:
kind: http
baseUrl: http://test_server/
`,
want: map[string]sources.SourceConfig{
"my-http-instance": http.Config{
Name: "my-http-instance",
Type: http.SourceType,
Kind: http.SourceKind,
BaseURL: "http://test_server/",
Timeout: "30s",
DisableSslVerification: false,
@@ -52,23 +52,23 @@ func TestParseFromYamlHttp(t *testing.T) {
{
desc: "advanced example",
in: `
kind: sources
name: my-http-instance
type: http
baseUrl: http://test_server/
timeout: 10s
headers:
Authorization: test_header
Custom-Header: custom
queryParams:
api-key: test_api_key
param: param-value
disableSslVerification: true
sources:
my-http-instance:
kind: http
baseUrl: http://test_server/
timeout: 10s
headers:
Authorization: test_header
Custom-Header: custom
queryParams:
api-key: test_api_key
param: param-value
disableSslVerification: true
`,
want: map[string]sources.SourceConfig{
"my-http-instance": http.Config{
Name: "my-http-instance",
Type: http.SourceType,
Kind: http.SourceKind,
BaseURL: "http://test_server/",
Timeout: "10s",
DefaultHeaders: map[string]string{"Authorization": "test_header", "Custom-Header": "custom"},
@@ -80,12 +80,16 @@ func TestParseFromYamlHttp(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -100,32 +104,36 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-http-instance
type: http
baseUrl: http://test_server/
timeout: 10s
headers:
Authorization: test_header
queryParams:
api-key: test_api_key
project: test-project
sources:
my-http-instance:
kind: http
baseUrl: http://test_server/
timeout: 10s
headers:
Authorization: test_header
queryParams:
api-key: test_api_key
project: test-project
`,
err: "error unmarshaling sources: unable to parse source \"my-http-instance\" as \"http\": [5:1] unknown field \"project\"\n 2 | headers:\n 3 | Authorization: test_header\n 4 | name: my-http-instance\n> 5 | project: test-project\n ^\n 6 | queryParams:\n 7 | api-key: test_api_key\n 8 | timeout: 10s\n 9 | ",
err: "unable to parse source \"my-http-instance\" as \"http\": [5:1] unknown field \"project\"\n 2 | headers:\n 3 | Authorization: test_header\n 4 | kind: http\n> 5 | project: test-project\n ^\n 6 | queryParams:\n 7 | api-key: test_api_key\n 8 | timeout: 10s",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-http-instance
baseUrl: http://test_server/
sources:
my-http-instance:
baseUrl: http://test_server/
`,
err: "error unmarshaling sources: missing 'type' field or it is not a string",
err: "missing 'kind' field for source \"my-http-instance\"",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -33,14 +33,14 @@ import (
v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4"
)
const SourceType string = "looker"
const SourceKind string = "looker"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -64,7 +64,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
BaseURL string `yaml:"base_url" validate:"required"`
ClientId string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
@@ -79,8 +79,8 @@ type Config struct {
SessionLength int64 `yaml:"sessionLength"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize initializes a Looker Source instance.
@@ -154,8 +154,8 @@ type Source struct {
AuthTokenHeaderName string
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,9 +15,9 @@
package looker_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -34,17 +34,17 @@ func TestParseFromYamlLooker(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-looker-instance
type: looker
base_url: http://example.looker.com/
client_id: jasdl;k;tjl
client_secret: sdakl;jgflkasdfkfg
sources:
my-looker-instance:
kind: looker
base_url: http://example.looker.com/
client_id: jasdl;k;tjl
client_secret: sdakl;jgflkasdfkfg
`,
want: map[string]sources.SourceConfig{
"my-looker-instance": looker.Config{
Name: "my-looker-instance",
Type: looker.SourceType,
Kind: looker.SourceKind,
BaseURL: "http://example.looker.com/",
ClientId: "jasdl;k;tjl",
ClientSecret: "sdakl;jgflkasdfkfg",
@@ -62,18 +62,22 @@ func TestParseFromYamlLooker(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
}
func TestFailParseFromYaml(t *testing.T) {
func TestFailParseFromYamlLooker(t *testing.T) {
tcs := []struct {
desc string
in string
@@ -82,30 +86,34 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-looker-instance
type: looker
base_url: http://example.looker.com/
client_id: jasdl;k;tjl
client_secret: sdakl;jgflkasdfkfg
schema: test-schema
sources:
my-looker-instance:
kind: looker
base_url: http://example.looker.com/
client_id: jasdl;k;tjl
client_secret: sdakl;jgflkasdfkfg
schema: test-schema
`,
err: "error unmarshaling sources: unable to parse source \"my-looker-instance\" as \"looker\": [5:1] unknown field \"schema\"\n 2 | client_id: jasdl;k;tjl\n 3 | client_secret: sdakl;jgflkasdfkfg\n 4 | name: my-looker-instance\n> 5 | schema: test-schema\n ^\n 6 | type: looker",
err: "unable to parse source \"my-looker-instance\" as \"looker\": [5:1] unknown field \"schema\"\n 2 | client_id: jasdl;k;tjl\n 3 | client_secret: sdakl;jgflkasdfkfg\n 4 | kind: looker\n> 5 | schema: test-schema\n ^\n",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-looker-instance
type: looker
client_id: jasdl;k;tjl
sources:
my-looker-instance:
kind: looker
client_id: jasdl;k;tjl
`,
err: "error unmarshaling sources: unable to parse source \"my-looker-instance\" as \"looker\": Key: 'Config.BaseURL' Error:Field validation for 'BaseURL' failed on the 'required' tag",
err: "unable to parse source \"my-looker-instance\" as \"looker\": Key: 'Config.BaseURL' Error:Field validation for 'BaseURL' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -27,14 +27,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "mindsdb"
const SourceKind string = "mindsdb"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -48,7 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -57,8 +57,8 @@ type Config struct {
QueryTimeout string `yaml:"queryTimeout"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -86,8 +86,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -159,7 +159,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initMindsDBConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Configure the driver to connect to the database

View File

@@ -15,12 +15,11 @@
package mindsdb_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/mindsdb"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,19 +33,19 @@ func TestParseFromYamlMindsDB(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-mindsdb-instance
type: mindsdb
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
sources:
my-mindsdb-instance:
kind: mindsdb
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-mindsdb-instance": mindsdb.Config{
Name: "my-mindsdb-instance",
Type: mindsdb.SourceType,
Kind: mindsdb.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -58,20 +57,20 @@ func TestParseFromYamlMindsDB(t *testing.T) {
{
desc: "with query timeout",
in: `
kind: sources
name: my-mindsdb-instance
type: mindsdb
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryTimeout: 45s
sources:
my-mindsdb-instance:
kind: mindsdb
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryTimeout: 45s
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-mindsdb-instance": mindsdb.Config{
Name: "my-mindsdb-instance",
Type: mindsdb.SourceType,
Kind: mindsdb.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -84,12 +83,16 @@ func TestParseFromYamlMindsDB(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -105,35 +108,39 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-mindsdb-instance
type: mindsdb
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
sources:
my-mindsdb-instance:
kind: mindsdb
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-mindsdb-instance\" as \"mindsdb\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-mindsdb-instance\n 5 | password: my_pass\n 6 | ",
err: "unable to parse source \"my-mindsdb-instance\" as \"mindsdb\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: mindsdb\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-mindsdb-instance
type: mindsdb
port: my-port
database: my_db
user: my_user
password: my_pass
sources:
my-mindsdb-instance:
kind: mindsdb
port: my-port
database: my_db
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-mindsdb-instance\" as \"mindsdb\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
err: "unable to parse source \"my-mindsdb-instance\" as \"mindsdb\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "mongodb"
const SourceKind string = "mongodb"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -50,12 +50,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Uri string `yaml:"uri" validate:"required"` // MongoDB Atlas connection URI
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -84,8 +84,8 @@ type Source struct {
Client *mongo.Client
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -293,7 +293,7 @@ func (s *Source) DeleteOne(ctx context.Context, filterString, database, collecti
func initMongoDBClient(ctx context.Context, tracer trace.Tracer, name, uri string) (*mongo.Client, error) {
// Start a tracing span
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -15,12 +15,11 @@
package mongodb_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/mongodb"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,15 +33,15 @@ func TestParseFromYamlMongoDB(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: mongo-db
type: "mongodb"
uri: "mongodb+srv://username:password@host/dbname"
sources:
mongo-db:
kind: "mongodb"
uri: "mongodb+srv://username:password@host/dbname"
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"mongo-db": mongodb.Config{
Name: "mongo-db",
Type: mongodb.SourceType,
Kind: mongodb.SourceKind,
Uri: "mongodb+srv://username:password@host/dbname",
},
},
@@ -50,12 +49,16 @@ func TestParseFromYamlMongoDB(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -71,27 +74,31 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: mongo-db
type: mongodb
uri: "mongodb+srv://username:password@host/dbname"
foo: bar
sources:
mongo-db:
kind: mongodb
uri: "mongodb+srv://username:password@host/dbname"
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"mongo-db\" as \"mongodb\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | name: mongo-db\n 3 | type: mongodb\n 4 | uri: mongodb+srv://username:password@host/dbname",
err: "unable to parse source \"mongo-db\" as \"mongodb\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: mongodb\n 3 | uri: mongodb+srv://username:password@host/dbname",
},
{
desc: "missing required field",
in: `
kind: sources
name: mongo-db
type: mongodb
sources:
mongo-db:
kind: mongodb
`,
err: "error unmarshaling sources: unable to parse source \"mongo-db\" as \"mongodb\": Key: 'Config.Uri' Error:Field validation for 'Uri' failed on the 'required' tag",
err: "unable to parse source \"mongo-db\" as \"mongodb\": Key: 'Config.Uri' Error:Field validation for 'Uri' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -28,14 +28,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "mssql"
const SourceKind string = "mssql"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Cloud SQL MSSQL configs
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -59,9 +59,9 @@ type Config struct {
Encrypt string `yaml:"encrypt"`
}
func (r Config) SourceConfigType() string {
// Returns Cloud SQL MSSQL source type
return SourceType
func (r Config) SourceConfigKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -91,9 +91,9 @@ type Source struct {
Db *sql.DB
}
func (s *Source) SourceType() string {
// Returns Cloud SQL MSSQL source type
return SourceType
func (s *Source) SourceKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -156,7 +156,7 @@ func initMssqlConnection(
error,
) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -15,12 +15,11 @@
package mssql_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/mssql"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,19 +33,19 @@ func TestParseFromYamlMssql(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-mssql-instance
type: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
sources:
my-mssql-instance:
kind: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-mssql-instance": mssql.Config{
Name: "my-mssql-instance",
Type: mssql.SourceType,
Kind: mssql.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -58,20 +57,20 @@ func TestParseFromYamlMssql(t *testing.T) {
{
desc: "with encrypt field",
in: `
kind: sources
name: my-mssql-instance
type: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
encrypt: strict
sources:
my-mssql-instance:
kind: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
encrypt: strict
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-mssql-instance": mssql.Config{
Name: "my-mssql-instance",
Type: mssql.SourceType,
Kind: mssql.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -84,12 +83,16 @@ func TestParseFromYamlMssql(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -104,35 +107,39 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-mssql-instance
type: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
sources:
my-mssql-instance:
kind: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-mssql-instance\" as \"mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-mssql-instance\n 5 | password: my_pass\n 6 | ",
err: "unable to parse source \"my-mssql-instance\" as \"mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: mssql\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-mssql-instance
type: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
sources:
my-mssql-instance:
kind: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
`,
err: "error unmarshaling sources: unable to parse source \"my-mssql-instance\" as \"mssql\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
err: "unable to parse source \"my-mssql-instance\" as \"mssql\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -30,14 +30,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "mysql"
const SourceKind string = "mysql"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -61,8 +61,8 @@ type Config struct {
QueryParams map[string]string `yaml:"queryParams"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -90,8 +90,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -158,7 +158,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string, queryParams map[string]string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Build query parameters via url.Values for deterministic order and proper escaping.

View File

@@ -19,12 +19,12 @@ import (
"strings"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"go.opentelemetry.io/otel/trace/noop"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/mysql"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -38,19 +38,19 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-mysql-instance
type: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
sources:
my-mysql-instance:
kind: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-mysql-instance": mysql.Config{
Name: "my-mysql-instance",
Type: mysql.SourceType,
Kind: mysql.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -62,20 +62,20 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "with query timeout",
in: `
kind: sources
name: my-mysql-instance
type: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryTimeout: 45s
sources:
my-mysql-instance:
kind: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryTimeout: 45s
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-mysql-instance": mysql.Config{
Name: "my-mysql-instance",
Type: mysql.SourceType,
Kind: mysql.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -88,22 +88,22 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "with query params",
in: `
kind: sources
name: my-mysql-instance
type: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryParams:
tls: preferred
charset: utf8mb4
sources:
my-mysql-instance:
kind: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryParams:
tls: preferred
charset: utf8mb4
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-mysql-instance": mysql.Config{
Name: "my-mysql-instance",
Type: mysql.SourceType,
Kind: mysql.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -120,11 +120,15 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got, cmpopts.EquateEmpty()); diff != "" {
if diff := cmp.Diff(tc.want, got.Sources, cmpopts.EquateEmpty()); diff != "" {
t.Fatalf("mismatch (-want +got):\n%s", diff)
}
})
@@ -141,51 +145,55 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-mysql-instance
type: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
sources:
my-mysql-instance:
kind: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-mysql-instance\n 5 | password: my_pass\n 6 | ",
err: "unknown field \"foo\"",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-mysql-instance
type: mysql
port: my-port
database: my_db
user: my_user
password: my_pass
sources:
my-mysql-instance:
kind: mysql
port: my-port
database: my_db
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"mysql\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
err: "Field validation for 'Host' failed",
},
{
desc: "invalid query params type",
in: `
kind: sources
name: my-mysql-instance
type: mysql
host: 0.0.0.0
port: 3306
database: my_db
user: my_user
password: my_pass
queryParams: not-a-map
sources:
my-mysql-instance:
kind: mysql
host: 0.0.0.0
port: 3306
database: my_db
user: my_user
password: my_pass
queryParams: not-a-map
`,
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"mysql\": [6:14] string was used where mapping is expected\n 3 | name: my-mysql-instance\n 4 | password: my_pass\n 5 | port: 3306\n> 6 | queryParams: not-a-map\n ^\n 7 | type: mysql\n 8 | user: my_user",
err: "string was used where mapping is expected",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}
@@ -203,7 +211,7 @@ func TestFailInitialization(t *testing.T) {
cfg := mysql.Config{
Name: "instance",
Type: "mysql",
Kind: "mysql",
Host: "localhost",
Port: "3306",
Database: "db",

View File

@@ -29,7 +29,7 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "neo4j"
const SourceKind string = "neo4j"
var sourceClassifier *classifier.QueryClassifier = classifier.NewQueryClassifier()
@@ -37,8 +37,8 @@ var sourceClassifier *classifier.QueryClassifier = classifier.NewQueryClassifier
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -52,15 +52,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Uri string `yaml:"uri" validate:"required"`
User string `yaml:"user" validate:"required"`
Password string `yaml:"password" validate:"required"`
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -91,8 +91,8 @@ type Source struct {
Driver neo4j.DriverWithContext
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -182,7 +182,7 @@ func addPlanChildren(p neo4j.Plan) []map[string]any {
func initNeo4jDriver(ctx context.Context, tracer trace.Tracer, uri, user, password, name string) (neo4j.DriverWithContext, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
auth := neo4j.BasicAuth(user, password, "")

View File

@@ -15,12 +15,11 @@
package neo4j_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/neo4j"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,18 +33,18 @@ func TestParseFromYamlNeo4j(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-neo4j-instance
type: neo4j
uri: neo4j+s://my-host:7687
database: my_db
user: my_user
password: my_pass
sources:
my-neo4j-instance:
kind: neo4j
uri: neo4j+s://my-host:7687
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-neo4j-instance": neo4j.Config{
Name: "my-neo4j-instance",
Type: neo4j.SourceType,
Kind: neo4j.SourceKind,
Uri: "neo4j+s://my-host:7687",
Database: "my_db",
User: "my_user",
@@ -56,12 +55,16 @@ func TestParseFromYamlNeo4j(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -77,33 +80,37 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-neo4j-instance
type: neo4j
uri: neo4j+s://my-host:7687
database: my_db
user: my_user
password: my_pass
foo: bar
sources:
my-neo4j-instance:
kind: neo4j
uri: neo4j+s://my-host:7687
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-neo4j-instance\" as \"neo4j\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | name: my-neo4j-instance\n 4 | password: my_pass\n 5 | type: neo4j\n 6 | ",
err: "unable to parse source \"my-neo4j-instance\" as \"neo4j\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | kind: neo4j\n 4 | password: my_pass\n 5 | uri: neo4j+s://my-host:7687\n 6 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-neo4j-instance
type: neo4j
uri: neo4j+s://my-host:7687
database: my_db
user: my_user
sources:
my-neo4j-instance:
kind: neo4j
uri: neo4j+s://my-host:7687
database: my_db
user: my_user
`,
err: "error unmarshaling sources: unable to parse source \"my-neo4j-instance\" as \"neo4j\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
err: "unable to parse source \"my-neo4j-instance\" as \"neo4j\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -27,14 +27,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "oceanbase"
const SourceKind string = "oceanbase"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -48,7 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -57,8 +57,8 @@ type Config struct {
QueryTimeout string `yaml:"queryTimeout"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -86,8 +86,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -153,7 +153,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
}
func initOceanBaseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
_, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, pass, host, port, dbname)

View File

@@ -15,12 +15,11 @@
package oceanbase_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/oceanbase"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -35,19 +34,19 @@ func TestParseFromYamlOceanBase(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-oceanbase-instance
type: oceanbase
host: 0.0.0.0
port: 2881
database: ob_db
user: ob_user
password: ob_pass
sources:
my-oceanbase-instance:
kind: oceanbase
host: 0.0.0.0
port: 2881
database: ob_db
user: ob_user
password: ob_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-oceanbase-instance": oceanbase.Config{
Name: "my-oceanbase-instance",
Type: oceanbase.SourceType,
Kind: oceanbase.SourceKind,
Host: "0.0.0.0",
Port: "2881",
Database: "ob_db",
@@ -59,20 +58,20 @@ func TestParseFromYamlOceanBase(t *testing.T) {
{
desc: "with query timeout",
in: `
kind: sources
name: my-oceanbase-instance
type: oceanbase
host: 0.0.0.0
port: 2881
database: ob_db
user: ob_user
password: ob_pass
queryTimeout: 30s
sources:
my-oceanbase-instance:
kind: oceanbase
host: 0.0.0.0
port: 2881
database: ob_db
user: ob_user
password: ob_pass
queryTimeout: 30s
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-oceanbase-instance": oceanbase.Config{
Name: "my-oceanbase-instance",
Type: oceanbase.SourceType,
Kind: oceanbase.SourceKind,
Host: "0.0.0.0",
Port: "2881",
Database: "ob_db",
@@ -85,12 +84,16 @@ func TestParseFromYamlOceanBase(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -106,35 +109,39 @@ func TestFailParseFromYamlOceanBase(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-oceanbase-instance
type: oceanbase
host: 0.0.0.0
port: 2881
database: ob_db
user: ob_user
password: ob_pass
foo: bar
sources:
my-oceanbase-instance:
kind: oceanbase
host: 0.0.0.0
port: 2881
database: ob_db
user: ob_user
password: ob_pass
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-oceanbase-instance\" as \"oceanbase\": [2:1] unknown field \"foo\"\n 1 | database: ob_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-oceanbase-instance\n 5 | password: ob_pass\n 6 | ",
err: "unable to parse source \"my-oceanbase-instance\" as \"oceanbase\": [2:1] unknown field \"foo\"\n 1 | database: ob_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: oceanbase\n 5 | password: ob_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-oceanbase-instance
type: oceanbase
port: 2881
database: ob_db
user: ob_user
password: ob_pass
sources:
my-oceanbase-instance:
kind: oceanbase
port: 2881
database: ob_db
user: ob_user
password: ob_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-oceanbase-instance\" as \"oceanbase\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
err: "unable to parse source \"my-oceanbase-instance\" as \"oceanbase\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -18,14 +18,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "oracle"
const SourceKind string = "oracle"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -45,7 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ConnectionString string `yaml:"connectionString,omitempty"`
TnsAlias string `yaml:"tnsAlias,omitempty"`
TnsAdmin string `yaml:"tnsAdmin,omitempty"`
@@ -95,8 +95,8 @@ func (c Config) validate() error {
return nil
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -124,8 +124,8 @@ type Source struct {
DB *sql.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -239,7 +239,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Config) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, config.Name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, config.Name)
defer span.End()
logger, err := util.LoggerFromContext(ctx)

View File

@@ -3,13 +3,12 @@
package oracle_test
import (
"context"
"strings"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -23,18 +22,18 @@ func TestParseFromYamlOracle(t *testing.T) {
{
desc: "connection string and useOCI=true",
in: `
kind: sources
name: my-oracle-cs
type: oracle
connectionString: "my-host:1521/XEPDB1"
user: my_user
password: my_pass
useOCI: true
`,
want: map[string]sources.SourceConfig{
sources:
my-oracle-cs:
kind: oracle
connectionString: "my-host:1521/XEPDB1"
user: my_user
password: my_pass
useOCI: true
`,
want: server.SourceConfigs{
"my-oracle-cs": oracle.Config{
Name: "my-oracle-cs",
Type: oracle.SourceType,
Kind: oracle.SourceKind,
ConnectionString: "my-host:1521/XEPDB1",
User: "my_user",
Password: "my_pass",
@@ -45,19 +44,19 @@ func TestParseFromYamlOracle(t *testing.T) {
{
desc: "host/port/serviceName and default useOCI=false",
in: `
kind: sources
name: my-oracle-host
type: oracle
host: my-host
port: 1521
serviceName: ORCLPDB
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
sources:
my-oracle-host:
kind: oracle
host: my-host
port: 1521
serviceName: ORCLPDB
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
"my-oracle-host": oracle.Config{
Name: "my-oracle-host",
Type: oracle.SourceType,
Kind: oracle.SourceKind,
Host: "my-host",
Port: 1521,
ServiceName: "ORCLPDB",
@@ -70,19 +69,19 @@ func TestParseFromYamlOracle(t *testing.T) {
{
desc: "tnsAlias and TnsAdmin specified with explicit useOCI=true",
in: `
kind: sources
name: my-oracle-tns-oci
type: oracle
tnsAlias: FINANCE_DB
tnsAdmin: /opt/oracle/network/admin
user: my_user
password: my_pass
useOCI: true
`,
want: map[string]sources.SourceConfig{
sources:
my-oracle-tns-oci:
kind: oracle
tnsAlias: FINANCE_DB
tnsAdmin: /opt/oracle/network/admin
user: my_user
password: my_pass
useOCI: true
`,
want: server.SourceConfigs{
"my-oracle-tns-oci": oracle.Config{
Name: "my-oracle-tns-oci",
Type: oracle.SourceType,
Kind: oracle.SourceKind,
TnsAlias: "FINANCE_DB",
TnsAdmin: "/opt/oracle/network/admin",
User: "my_user",
@@ -94,18 +93,22 @@ func TestParseFromYamlOracle(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse:\nwant: %v\ngot: %v\ndiff: %s", tc.want, got, cmp.Diff(tc.want, got))
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse:\nwant: %v\ngot: %v\ndiff: %s", tc.want, got.Sources, cmp.Diff(tc.want, got.Sources))
}
})
}
}
func TestFailParseFromYaml(t *testing.T) {
func TestFailParseFromYamlOracle(t *testing.T) {
tcs := []struct {
desc string
in string
@@ -114,72 +117,76 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-oracle-instance
type: oracle
host: my-host
serviceName: ORCL
user: my_user
password: my_pass
extraField: value
sources:
my-oracle-instance:
kind: oracle
host: my-host
serviceName: ORCL
user: my_user
password: my_pass
extraField: value
`,
err: "error unmarshaling sources: unable to parse source \"my-oracle-instance\" as \"oracle\": [1:1] unknown field \"extraField\"\n> 1 | extraField: value\n ^\n 2 | host: my-host\n 3 | name: my-oracle-instance\n 4 | password: my_pass\n 5 | ",
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": [1:1] unknown field \"extraField\"\n> 1 | extraField: value\n ^\n 2 | host: my-host\n 3 | kind: oracle\n 4 | password: my_pass\n 5 | ",
},
{
desc: "missing required password field",
in: `
kind: sources
name: my-oracle-instance
type: oracle
host: my-host
serviceName: ORCL
user: my_user
`,
err: "error unmarshaling sources: unable to parse source \"my-oracle-instance\" as \"oracle\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
sources:
my-oracle-instance:
kind: oracle
host: my-host
serviceName: ORCL
user: my_user
`,
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
},
{
desc: "missing connection method fields (validate fails)",
in: `
kind: sources
name: my-oracle-instance
type: oracle
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'",
sources:
my-oracle-instance:
kind: oracle
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'",
},
{
desc: "multiple connection methods provided (validate fails)",
in: `
kind: sources
name: my-oracle-instance
type: oracle
host: my-host
serviceName: ORCL
connectionString: "my-host:1521/XEPDB1"
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'",
sources:
my-oracle-instance:
kind: oracle
host: my-host
serviceName: ORCL
connectionString: "my-host:1521/XEPDB1"
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'",
},
{
desc: "fail on tnsAdmin with useOCI=false",
in: `
kind: sources
name: my-oracle-fail
type: oracle
tnsAlias: FINANCE_DB
tnsAdmin: /opt/oracle/network/admin
user: my_user
password: my_pass
useOCI: false
sources:
my-oracle-fail:
kind: oracle
tnsAlias: FINANCE_DB
tnsAdmin: /opt/oracle/network/admin
user: my_user
password: my_pass
useOCI: false
`,
err: "error unmarshaling sources: unable to parse source \"my-oracle-fail\" as \"oracle\": invalid Oracle configuration: `tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead",
err: "unable to parse source \"my-oracle-fail\" as \"oracle\": invalid Oracle configuration: `tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -28,14 +28,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "postgres"
const SourceKind string = "postgres"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -58,8 +58,8 @@ type Config struct {
QueryParams map[string]string `yaml:"queryParams"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -87,8 +87,8 @@ type Source struct {
Pool *pgxpool.Pool
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -128,7 +128,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)
if err != nil {

View File

@@ -1,4 +1,4 @@
// Copyright 2025 Google LLC
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,14 +15,13 @@
package postgres_test
import (
"context"
"sort"
"strings"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -36,19 +35,19 @@ func TestParseFromYamlPostgres(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-pg-instance
type: postgres
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
sources:
my-pg-instance:
kind: postgres
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-pg-instance": postgres.Config{
Name: "my-pg-instance",
Type: postgres.SourceType,
Kind: postgres.SourceKind,
Host: "my-host",
Port: "my-port",
Database: "my_db",
@@ -60,22 +59,22 @@ func TestParseFromYamlPostgres(t *testing.T) {
{
desc: "example with query params",
in: `
kind: sources
name: my-pg-instance
type: postgres
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
queryParams:
sslmode: verify-full
sslrootcert: /tmp/ca.crt
sources:
my-pg-instance:
kind: postgres
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
queryParams:
sslmode: verify-full
sslrootcert: /tmp/ca.crt
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-pg-instance": postgres.Config{
Name: "my-pg-instance",
Type: postgres.SourceType,
Kind: postgres.SourceKind,
Host: "my-host",
Port: "my-port",
Database: "my_db",
@@ -91,12 +90,16 @@ func TestParseFromYamlPostgres(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -112,35 +115,39 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-pg-instance
type: postgres
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
sources:
my-pg-instance:
kind: postgres
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"postgres\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: my-host\n 4 | name: my-pg-instance\n 5 | password: my_pass\n 6 | ",
err: "unable to parse source \"my-pg-instance\" as \"postgres\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: my-host\n 4 | kind: postgres\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-pg-instance
type: postgres
host: my-host
port: my-port
database: my_db
user: my_user
sources:
my-pg-instance:
kind: postgres
host: my-host
port: my-port
database: my_db
user: my_user
`,
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"postgres\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
err: "unable to parse source \"my-pg-instance\" as \"postgres\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -24,14 +24,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "redis"
const SourceKind string = "redis"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -45,7 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Address []string `yaml:"address" validate:"required"`
Username string `yaml:"username"`
Password string `yaml:"password"`
@@ -54,8 +54,8 @@ type Config struct {
ClusterEnabled bool `yaml:"clusterEnabled"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
// RedisClient is an interface for `redis.Client` and `redis.ClusterClient
@@ -141,8 +141,8 @@ type Source struct {
Client RedisClient
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,13 +15,12 @@
package redis_test
import (
"context"
"strings"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/redis"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -35,16 +34,16 @@ func TestParseFromYamlRedis(t *testing.T) {
{
desc: "default setting",
in: `
kind: sources
name: my-redis-instance
type: redis
address:
- 127.0.0.1
sources:
my-redis-instance:
kind: redis
address:
- 127.0.0.1
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-redis-instance": redis.Config{
Name: "my-redis-instance",
Type: redis.SourceType,
Kind: redis.SourceKind,
Address: []string{"127.0.0.1"},
ClusterEnabled: false,
UseGCPIAM: false,
@@ -54,20 +53,20 @@ func TestParseFromYamlRedis(t *testing.T) {
{
desc: "advanced example",
in: `
kind: sources
name: my-redis-instance
type: redis
address:
- 127.0.0.1
password: my-pass
database: 1
useGCPIAM: true
clusterEnabled: true
sources:
my-redis-instance:
kind: redis
address:
- 127.0.0.1
password: my-pass
database: 1
useGCPIAM: true
clusterEnabled: true
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-redis-instance": redis.Config{
Name: "my-redis-instance",
Type: redis.SourceType,
Kind: redis.SourceKind,
Address: []string{"127.0.0.1"},
Password: "my-pass",
Database: 1,
@@ -79,12 +78,16 @@ func TestParseFromYamlRedis(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -100,43 +103,48 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "invalid database",
in: `
kind: sources
name: my-redis-instance
type: redis
address:
- 127.0.0.1
password: my-pass
database: data
sources:
my-redis-instance:
kind: redis
project: my-project
address:
- 127.0.0.1
password: my-pass
database: data
`,
err: "error unmarshaling sources: unable to parse source \"my-redis-instance\" as \"redis\": [3:11] cannot unmarshal string into Go struct field Config.Database of type int\n 1 | address:\n 2 | - 127.0.0.1\n> 3 | database: data\n ^\n 4 | name: my-redis-instance\n 5 | password: my-pass\n 6 | type: redis",
err: "cannot unmarshal string into Go struct field .Sources of type int",
},
{
desc: "extra field",
in: `
kind: sources
name: my-redis-instance
type: redis
project: my-project
address:
- 127.0.0.1
password: my-pass
database: 1
sources:
my-redis-instance:
kind: redis
project: my-project
address:
- 127.0.0.1
password: my-pass
database: 1
`,
err: "error unmarshaling sources: unable to parse source \"my-redis-instance\" as \"redis\": [6:1] unknown field \"project\"\n 3 | database: 1\n 4 | name: my-redis-instance\n 5 | password: my-pass\n> 6 | project: my-project\n ^\n 7 | type: redis",
err: "unable to parse source \"my-redis-instance\" as \"redis\": [6:1] unknown field \"project\"",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-redis-instance
type: redis
sources:
my-redis-instance:
kind: redis
`,
err: "error unmarshaling sources: unable to parse source \"my-redis-instance\" as \"redis\": Key: 'Config.Address' Error:Field validation for 'Address' failed on the 'required' tag",
err: "unable to parse source \"my-redis-instance\" as \"redis\": Key: 'Config.Address' Error:Field validation for 'Address' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -33,14 +33,14 @@ import (
"google.golang.org/protobuf/encoding/protojson"
)
const SourceType string = "serverless-spark"
const SourceKind string = "serverless-spark"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -54,13 +54,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Location string `yaml:"location" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -94,8 +94,8 @@ type Source struct {
OpsClient *longrunning.OperationsClient
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,12 +15,11 @@
package serverlessspark_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,16 +33,16 @@ func TestParseFromYamlServerlessSpark(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-instance
type: serverless-spark
project: my-project
location: my-location
sources:
my-instance:
kind: serverless-spark
project: my-project
location: my-location
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-instance": serverlessspark.Config{
Name: "my-instance",
Type: serverlessspark.SourceType,
Kind: serverlessspark.SourceKind,
Project: "my-project",
Location: "my-location",
},
@@ -52,12 +51,16 @@ func TestParseFromYamlServerlessSpark(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -73,39 +76,43 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-instance
type: serverless-spark
project: my-project
location: my-location
foo: bar
sources:
my-instance:
kind: serverless-spark
project: my-project
location: my-location
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"serverless-spark\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | location: my-location\n 3 | name: my-instance\n 4 | project: my-project\n 5 | ",
err: "unable to parse source \"my-instance\" as \"serverless-spark\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: serverless-spark\n 3 | location: my-location\n 4 | project: my-project",
},
{
desc: "missing required field project",
in: `
kind: sources
name: my-instance
type: serverless-spark
location: my-location
sources:
my-instance:
kind: serverless-spark
location: my-location
`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"serverless-spark\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "unable to parse source \"my-instance\" as \"serverless-spark\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
{
desc: "missing required field location",
in: `
kind: sources
name: my-instance
type: serverless-spark
project: my-project
sources:
my-instance:
kind: serverless-spark
project: my-project
`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"serverless-spark\": Key: 'Config.Location' Error:Field validation for 'Location' failed on the 'required' tag",
err: "unable to parse source \"my-instance\" as \"serverless-spark\": Key: 'Config.Location' Error:Field validation for 'Location' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,15 +29,15 @@ import (
"go.opentelemetry.io/otel/trace"
)
// SourceType for SingleStore source
const SourceType string = "singlestore"
// SourceKind for SingleStore source
const SourceKind string = "singlestore"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -52,7 +52,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
// Config holds the configuration parameters for connecting to a SingleStore database.
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -61,9 +61,9 @@ type Config struct {
QueryTimeout string `yaml:"queryTimeout"`
}
// SourceConfigType returns the type of the source configuration.
func (r Config) SourceConfigType() string {
return SourceType
// SourceConfigKind returns the kind of the source configuration.
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize sets up the SingleStore connection pool and returns a Source.
@@ -93,9 +93,9 @@ type Source struct {
Pool *sql.DB
}
// SourceType returns the type of the source configuration.
func (s *Source) SourceType() string {
return SourceType
// SourceKind returns the kind of the source configuration.
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -162,7 +162,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initSingleStoreConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Configure the driver to connect to the database

View File

@@ -1,3 +1,5 @@
package singlestore_test
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,15 +14,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package singlestore_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/singlestore"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,19 +33,19 @@ func TestParseFromYaml(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-s2-instance
type: singlestore
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
sources:
my-s2-instance:
kind: singlestore
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-s2-instance": singlestore.Config{
Name: "my-s2-instance",
Type: singlestore.SourceType,
Kind: singlestore.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -58,20 +57,20 @@ func TestParseFromYaml(t *testing.T) {
{
desc: "with query timeout",
in: `
kind: sources
name: my-s2-instance
type: singlestore
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryTimeout: 45s
sources:
my-s2-instance:
kind: singlestore
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryTimeout: 45s
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-s2-instance": singlestore.Config{
Name: "my-s2-instance",
Type: singlestore.SourceType,
Kind: singlestore.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -84,12 +83,16 @@ func TestParseFromYaml(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -105,35 +108,39 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-s2-instance
type: singlestore
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
sources:
my-s2-instance:
kind: singlestore
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-s2-instance\" as \"singlestore\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-s2-instance\n 5 | password: my_pass\n 6 | ",
err: "unable to parse source \"my-s2-instance\" as \"singlestore\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: singlestore\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-s2-instance
type: singlestore
port: my-port
database: my_db
user: my_user
password: my_pass
sources:
my-s2-instance:
kind: singlestore
port: my-port
database: my_db
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-s2-instance\" as \"singlestore\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
err: "unable to parse source \"my-s2-instance\" as \"singlestore\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -25,14 +25,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceType string = "snowflake"
const SourceKind string = "snowflake"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -46,7 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Account string `yaml:"account" validate:"required"`
User string `yaml:"user" validate:"required"`
Password string `yaml:"password" validate:"required"`
@@ -56,8 +56,8 @@ type Config struct {
Role string `yaml:"role"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -85,8 +85,8 @@ type Source struct {
DB *sqlx.DB
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -137,7 +137,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initSnowflakeConnection(ctx context.Context, tracer trace.Tracer, name, account, user, password, database, schema, warehouse, role string) (*sqlx.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Set defaults for optional parameters

View File

@@ -15,12 +15,11 @@
package snowflake_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/snowflake"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,19 +33,19 @@ func TestParseFromYamlSnowflake(t *testing.T) {
{
desc: "basic example",
in: `
kind: sources
name: my-snowflake-instance
type: snowflake
account: my-account
user: my_user
password: my_pass
database: my_db
schema: my_schema
sources:
my-snowflake-instance:
kind: snowflake
account: my-account
user: my_user
password: my_pass
database: my_db
schema: my_schema
`,
want: map[string]sources.SourceConfig{
want: server.SourceConfigs{
"my-snowflake-instance": snowflake.Config{
Name: "my-snowflake-instance",
Type: snowflake.SourceType,
Kind: snowflake.SourceKind,
Account: "my-account",
User: "my_user",
Password: "my_pass",
@@ -60,12 +59,16 @@ func TestParseFromYamlSnowflake(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
@@ -81,35 +84,39 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
kind: sources
name: my-snowflake-instance
type: snowflake
account: my-account
user: my_user
password: my_pass
database: my_db
schema: my_schema
foo: bar
sources:
my-snowflake-instance:
kind: snowflake
account: my-account
user: my_user
password: my_pass
database: my_db
schema: my_schema
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-snowflake-instance\" as \"snowflake\": [3:1] unknown field \"foo\"\n 1 | account: my-account\n 2 | database: my_db\n> 3 | foo: bar\n ^\n 4 | name: my-snowflake-instance\n 5 | password: my_pass\n 6 | schema: my_schema\n 7 | ",
err: "unable to parse source \"my-snowflake-instance\" as \"snowflake\": [3:1] unknown field \"foo\"\n 1 | account: my-account\n 2 | database: my_db\n> 3 | foo: bar\n ^\n 4 | kind: snowflake\n 5 | password: my_pass\n 6 | schema: my_schema\n 7 | ",
},
{
desc: "missing required field",
in: `
kind: sources
name: my-snowflake-instance
type: snowflake
account: my-account
user: my_user
password: my_pass
database: my_db
sources:
my-snowflake-instance:
kind: snowflake
account: my-account
user: my_user
password: my_pass
database: my_db
`,
err: "error unmarshaling sources: unable to parse source \"my-snowflake-instance\" as \"snowflake\": Key: 'Config.Schema' Error:Field validation for 'Schema' failed on the 'required' tag",
err: "unable to parse source \"my-snowflake-instance\" as \"snowflake\": Key: 'Config.Schema' Error:Field validation for 'Schema' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,48 +29,48 @@ type SourceConfigFactory func(ctx context.Context, name string, decoder *yaml.De
var sourceRegistry = make(map[string]SourceConfigFactory)
// Register registers a new source type with its factory.
// It returns false if the type is already registered.
func Register(sourceType string, factory SourceConfigFactory) bool {
if _, exists := sourceRegistry[sourceType]; exists {
// Source with this type already exists, do not overwrite.
// Register registers a new source kind with its factory.
// It returns false if the kind is already registered.
func Register(kind string, factory SourceConfigFactory) bool {
if _, exists := sourceRegistry[kind]; exists {
// Source with this kind already exists, do not overwrite.
return false
}
sourceRegistry[sourceType] = factory
sourceRegistry[kind] = factory
return true
}
// DecodeConfig decodes a source configuration using the registered factory for the given type.
func DecodeConfig(ctx context.Context, sourceType string, name string, decoder *yaml.Decoder) (SourceConfig, error) {
factory, found := sourceRegistry[sourceType]
// DecodeConfig decodes a source configuration using the registered factory for the given kind.
func DecodeConfig(ctx context.Context, kind string, name string, decoder *yaml.Decoder) (SourceConfig, error) {
factory, found := sourceRegistry[kind]
if !found {
return nil, fmt.Errorf("unknown source type: %q", sourceType)
return nil, fmt.Errorf("unknown source kind: %q", kind)
}
sourceConfig, err := factory(ctx, name, decoder)
if err != nil {
return nil, fmt.Errorf("unable to parse source %q as %q: %w", name, sourceType, err)
return nil, fmt.Errorf("unable to parse source %q as %q: %w", name, kind, err)
}
return sourceConfig, err
}
// SourceConfig is the interface for configuring a source.
type SourceConfig interface {
SourceConfigType() string
SourceConfigKind() string
Initialize(ctx context.Context, tracer trace.Tracer) (Source, error)
}
// Source is the interface for the source itself.
type Source interface {
SourceType() string
SourceKind() string
ToConfig() SourceConfig
}
// InitConnectionSpan adds a span for database pool connection initialization
func InitConnectionSpan(ctx context.Context, tracer trace.Tracer, sourceType, sourceName string) (context.Context, trace.Span) {
func InitConnectionSpan(ctx context.Context, tracer trace.Tracer, sourceKind, sourceName string) (context.Context, trace.Span) {
ctx, span := tracer.Start(
ctx,
"toolbox/server/source/connect",
trace.WithAttributes(attribute.String("source_type", sourceType)),
trace.WithAttributes(attribute.String("source_kind", sourceKind)),
trace.WithAttributes(attribute.String("source_name", sourceName)),
)
return ctx, span

View File

@@ -28,14 +28,14 @@ import (
"google.golang.org/api/iterator"
)
const SourceType string = "spanner"
const SourceKind string = "spanner"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
@@ -49,15 +49,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
Dialect sources.Dialect `yaml:"dialect" validate:"required"`
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigType() string {
return SourceType
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -80,8 +80,8 @@ type Source struct {
Client *spanner.Client
}
func (s *Source) SourceType() string {
return SourceType
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -171,7 +171,7 @@ func (s *Source) RunSQL(ctx context.Context, readOnly bool, statement string, pa
func initSpannerClient(ctx context.Context, tracer trace.Tracer, name, project, instance, dbname string) (*spanner.Client, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Configure the connection to the database

Some files were not shown because too many files have changed in this diff Show More