Compare commits

..

5 Commits

Author SHA1 Message Date
Yuan Teoh
ad8df40791 chore: update yaml tag for auth, embedding model, prompts, sources 2026-01-21 22:49:22 -08:00
Yuan Teoh
c29355ff82 chore: update unmarshal function for ToolsFile 2026-01-21 22:49:06 -08:00
Yuan Teoh
70f5550910 Update cmd/root_test.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-20 11:25:03 -08:00
Yuan Teoh
348c9fde08 chore: add preprocessing function to convert tools file 2026-01-20 11:25:03 -08:00
Yuan Teoh
aef539bcf3 refactor!: update Kind field to Type in source code (#2312)
Update source code `Kind` to `Type`. It's only changes within our code.
Changes to yaml tag (that will affect users) will be done in later PRs.

This is a breaking change since it updates telemetry's span attribute
from `source_kind` to `source_type`.

Related #817

Future updates will include: 
* Updating a preprocessing function to convert config file from v1 to v2
* Update unmarshal function for ToolsFile to convert config file (test
will fail since the yaml tag is not yet updated).
* Update yaml tag (test will pass).
2026-01-20 11:20:41 -08:00
550 changed files with 5636 additions and 8952 deletions

View File

@@ -87,7 +87,7 @@ steps:
- "CLOUD_SQL_POSTGRES_REGION=$_REGION"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv:
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID"]
volumes:
- name: "go"
path: "/gopath"
@@ -134,7 +134,7 @@ steps:
- "ALLOYDB_POSTGRES_DATABASE=$_DATABASE_NAME"
- "ALLOYDB_POSTGRES_REGION=$_REGION"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID"]
volumes:
- name: "go"
path: "/gopath"
@@ -293,7 +293,7 @@ steps:
.ci/test_with_coverage.sh \
"Cloud Healthcare API" \
cloudhealthcare \
cloudhealthcare
cloudhealthcare || echo "Integration tests failed."
- id: "postgres"
name: golang:1
@@ -305,7 +305,7 @@ steps:
- "POSTGRES_HOST=$_POSTGRES_HOST"
- "POSTGRES_PORT=$_POSTGRES_PORT"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID"]
volumes:
- name: "go"
path: "/gopath"
@@ -964,13 +964,6 @@ steps:
availableSecrets:
secretManager:
# Common secrets
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
env: CLIENT_ID
- versionName: projects/$PROJECT_ID/secrets/api_key/versions/latest
env: API_KEY
# Resource-specific secrets
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
env: CLOUD_SQL_POSTGRES_USER
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_pass/versions/latest
@@ -987,6 +980,8 @@ availableSecrets:
env: POSTGRES_USER
- versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest
env: POSTGRES_PASS
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
env: CLIENT_ID
- versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest
env: NEO4J_USER
- versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest

View File

@@ -77,24 +77,11 @@ run_orch_test() {
setup_orch_table
cd "$orch_dir"
echo "--- Debugging npm config for $orch_name ---"
npm config list
echo "--- Active Registry for $orch_name ---"
npm config get registry
echo "--- Inspecting .npmrc files ---"
[ -f ".npmrc" ] && echo "Local .npmrc content:" && cat .npmrc
[ -f "$HOME/.npmrc" ] && echo "Global .npmrc content:" && cat "$HOME/.npmrc"
export GPKG_DEBUG=1
echo "Installing dependencies for $orch_name..."
if [ -f "package-lock.json" ]; then
npm ci --loglevel verbose
npm ci
else
npm install --loglevel verbose
npm install
fi
cd ..

View File

@@ -92,11 +92,11 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
`newdb.go`. Create a `Config` struct to include all the necessary parameters
for connecting to the database (e.g., host, port, username, password, database
name) and a `Source` struct to store necessary parameters for tools (e.g.,
Name, Kind, connection object, additional config).
Name, Type, connection object, additional config).
* **Implement the
[`SourceConfig`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/sources/sources.go#L57)
interface**. This interface requires two methods:
* `SourceConfigKind() string`: Returns a unique string identifier for your
* `SourceConfigType() string`: Returns a unique string identifier for your
data source (e.g., `"newdb"`).
* `Initialize(ctx context.Context, tracer trace.Tracer) (Source, error)`:
Creates a new instance of your data source and establishes a connection to
@@ -104,7 +104,7 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
* **Implement the
[`Source`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/sources/sources.go#L63)
interface**. This interface requires one method:
* `SourceKind() string`: Returns the same string identifier as `SourceConfigKind()`.
* `SourceType() string`: Returns the same string identifier as `SourceConfigType()`.
* **Implement `init()`** to register the new Source.
* **Implement Unit Tests** in a file named `newdb_test.go`.
@@ -126,7 +126,7 @@ tools.
* **Implement the
[`ToolConfig`](https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/internal/tools/tools.go#L61)
interface**. This interface requires one method:
* `ToolConfigKind() string`: Returns a unique string identifier for your tool
* `ToolConfigType() string`: Returns a unique string identifier for your tool
(e.g., `"newdb-tool"`).
* `Initialize(sources map[string]Source) (Tool, error)`: Creates a new
instance of your tool and validates that it can connect to the specified
@@ -243,7 +243,7 @@ resources.
| style | Update src code, with only formatting and whitespace updates (e.g. code formatter or linter changes). |
Pull requests should always add scope whenever possible. The scope is
formatted as `<scope-type>/<scope-kind>` (e.g., `sources/postgres`, or
formatted as `<scope-resource>/<scope-type>` (e.g., `sources/postgres`, or
`tools/mssql-sql`).
Ideally, **each PR covers only one scope**, if this is

View File

@@ -954,7 +954,7 @@ For more details on configuring different types of sources, see the
### Tools
The `tools` section of a `tools.yaml` define the actions an agent can take: what
kind of tool it is, which source(s) it affects, what parameters it uses, etc.
type of tool it is, which source(s) it affects, what parameters it uses, etc.
```yaml
tools:

View File

@@ -15,6 +15,7 @@
package cmd
import (
"bytes"
"context"
_ "embed"
"fmt"
@@ -98,7 +99,6 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistdatabases"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistinstances"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance"
@@ -386,7 +386,6 @@ func NewCommand(opts ...Option) *Command {
// TODO: Insecure by default. Might consider updating this for v1.0.0
flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.")
flags.StringSliceVar(&cmd.cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.")
flags.StringSliceVar(&cmd.cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.")
// wrap RunE command so that we have access to original Command object
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }
@@ -396,7 +395,6 @@ func NewCommand(opts ...Option) *Command {
type ToolsFile struct {
Sources server.SourceConfigs `yaml:"sources"`
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"`
Tools server.ToolConfigs `yaml:"tools"`
@@ -427,6 +425,106 @@ func parseEnv(input string) (string, error) {
return output, err
}
func convertToolsFile(ctx context.Context, raw []byte) ([]byte, error) {
var input yaml.MapSlice
decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap())
if err := decoder.Decode(&input); err != nil {
return nil, err
}
// Convert raw MapSlice to a helper map for quick lookup
// while keeping the values as MapSlices to preserve internal order
resourceOrder := []string{}
lookup := make(map[string]yaml.MapSlice)
for _, item := range input {
key, ok := item.Key.(string)
if !ok {
return nil, fmt.Errorf("unexpected non-string key in input: %v", item.Key)
}
if slice, ok := item.Value.(yaml.MapSlice); ok {
// convert authSources to authServices
if key == "authSources" {
key = "authServices"
}
// works even if lookup[key] is nil
lookup[key] = append(lookup[key], slice...)
// preserving the resource's order of original toolsFile
if !slices.Contains(resourceOrder, key) {
resourceOrder = append(resourceOrder, key)
}
} else {
// toolsfile is already v2
if key == "kind" {
return raw, nil
}
return nil, fmt.Errorf("'%s' is not a map", key)
}
}
// convert to tools file v2
var buf bytes.Buffer
encoder := yaml.NewEncoder(&buf)
for _, kind := range resourceOrder {
data, exists := lookup[kind]
if !exists {
// if this is skipped for all keys, the tools file is in v2
continue
}
// Transform each entry
for _, entry := range data {
entryName, ok := entry.Key.(string)
if !ok {
return nil, fmt.Errorf("unexpected non-string key for entry in '%s': %v", kind, entry.Key)
}
entryBody := ProcessValue(entry.Value, kind == "toolsets")
transformed := yaml.MapSlice{
{Key: "kind", Value: kind},
{Key: "name", Value: entryName},
}
// Merge the transformed body into our result
if bodySlice, ok := entryBody.(yaml.MapSlice); ok {
transformed = append(transformed, bodySlice...)
} else {
return nil, fmt.Errorf("unable to convert entryBody to MapSlice")
}
if err := encoder.Encode(transformed); err != nil {
return nil, err
}
}
}
return buf.Bytes(), nil
}
// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type'
func ProcessValue(v any, isToolset bool) any {
switch val := v.(type) {
case yaml.MapSlice:
for i := range val {
// Perform renaming
if val[i].Key == "kind" {
val[i].Key = "type"
}
// Recursive call for nested values (e.g., nested objects or lists)
val[i].Value = ProcessValue(val[i].Value, false)
}
return val
case []any:
// Process lists: If it's a toolset top-level list, wrap it.
if isToolset {
return yaml.MapSlice{{Key: "tools", Value: val}}
}
// Otherwise, recurse into list items (to catch nested objects)
for i := range val {
val[i] = ProcessValue(val[i], false)
}
return val
default:
return val
}
}
// parseToolsFile parses the provided yaml into appropriate configs.
func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
var toolsFile ToolsFile
@@ -437,8 +535,13 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
}
raw = []byte(output)
raw, err = convertToolsFile(ctx, raw)
if err != nil {
return toolsFile, fmt.Errorf("error converting tools file: %s", err)
}
// Parse contents
err = yaml.UnmarshalContext(ctx, raw, &toolsFile, yaml.Strict())
toolsFile.Sources, toolsFile.AuthServices, toolsFile.EmbeddingModels, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts, err = server.UnmarshalResourceConfig(ctx, raw)
if err != nil {
return toolsFile, err
}
@@ -470,18 +573,6 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
}
}
// Check for conflicts and merge authSources (deprecated, but still support)
for name, authSource := range file.AuthSources {
if _, exists := merged.AuthSources[name]; exists {
conflicts = append(conflicts, fmt.Sprintf("authSource '%s' (file #%d)", name, fileIndex+1))
} else {
if merged.AuthSources == nil {
merged.AuthSources = make(server.AuthServiceConfigs)
}
merged.AuthSources[name] = authSource
}
}
// Check for conflicts and merge authServices
for name, authService := range file.AuthServices {
if _, exists := merged.AuthServices[name]; exists {
@@ -957,20 +1048,6 @@ func run(cmd *Command) error {
cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets
cmd.cfg.PromptConfigs = finalToolsFile.Prompts
authSourceConfigs := finalToolsFile.AuthSources
if authSourceConfigs != nil {
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
for k, v := range authSourceConfigs {
if _, exists := cmd.cfg.AuthServiceConfigs[k]; exists {
errMsg := fmt.Errorf("resource conflict detected: authSource '%s' has the same name as an existing authService. Please rename your authSource", k)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
cmd.cfg.AuthServiceConfigs[k] = v
}
}
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
if err != nil {
errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err)

View File

@@ -23,12 +23,14 @@ import (
"os"
"path"
"path/filepath"
"reflect"
"regexp"
"runtime"
"strings"
"testing"
"time"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/auth/google"
@@ -70,9 +72,6 @@ func withDefaults(c server.ServerConfig) server.ServerConfig {
if c.AllowedHosts == nil {
c.AllowedHosts = []string{"*"}
}
if c.UserAgentMetadata == nil {
c.UserAgentMetadata = []string{}
}
return c
}
@@ -233,13 +232,6 @@ func TestServerConfigFlags(t *testing.T) {
AllowedHosts: []string{"http://foo.com", "http://bar.com"},
}),
},
{
desc: "user agent metadata",
args: []string{"--user-agent-metadata", "foo,bar"},
want: withDefaults(server.ServerConfig{
UserAgentMetadata: []string{"foo", "bar"},
}),
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
@@ -504,6 +496,309 @@ func TestDefaultLogLevel(t *testing.T) {
}
}
func TestConvertToolsFile(t *testing.T) {
ctx, cancelCtx := context.WithTimeout(context.Background(), time.Minute)
defer cancelCtx()
pr, pw := io.Pipe()
defer pw.Close()
defer pr.Close()
logger, err := log.NewStdLogger(pw, pw, "DEBUG")
if err != nil {
t.Fatalf("failed to setup logger %s", err)
}
ctx = util.WithLogger(ctx, logger)
tcs := []struct {
desc string
in string
want string
isErr bool
errStr string
}{
{
desc: "basic convert",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
authServices:
my-google-auth:
kind: google
clientId: testing-id
tools:
example_tool:
kind: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
toolsets:
example_toolset:
- example_tool
prompts:
code_review:
description: ask llm to analyze code quality
messages:
- content: "please review the following code for quality: {{.code}}"
arguments:
- name: code
description: the code to review
embeddingModels:
gemini-model:
kind: gemini
model: gemini-embedding-001
apiKey: some-key
dimension: 768`,
want: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: authServices
name: my-google-auth
type: google
clientId: testing-id
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool
---
kind: prompts
name: code_review
description: ask llm to analyze code quality
messages:
- content: "please review the following code for quality: {{.code}}"
arguments:
- name: code
description: the code to review
---
kind: embeddingModels
name: gemini-model
type: gemini
model: gemini-embedding-001
apiKey: some-key
dimension: 768`,
},
{
desc: "preserve resource order with grouping",
in: `
tools:
example_tool:
kind: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
authServices:
my-google-auth:
kind: google
clientId: testing-id
toolsets:
example_toolset:
- example_tool
authSources:
my-google-auth:
kind: google
clientId: testing-id`,
want: `
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: authServices
name: my-google-auth
type: google
clientId: testing-id
---
kind: authServices
name: my-google-auth
type: google
clientId: testing-id
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
},
{
desc: "no convertion needed",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
want: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
},
{
desc: "invalid source",
in: `sources: invalid`,
isErr: true,
errStr: "'sources' is not a map",
},
{
desc: "invalid toolset",
in: `toolsets: invalid`,
isErr: true,
errStr: "'toolsets' is not a map",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
output, err := convertToolsFile(ctx, []byte(tc.in))
if tc.isErr {
if err == nil {
t.Fatalf("missing error: %s", tc.errStr)
}
if err.Error() != tc.errStr {
t.Fatalf("invalid error string: got %s, want %s", err, tc.errStr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
var docs1, docs2 []yaml.MapSlice
if docs1, err = decodeToMapSlice(string(output)); err != nil {
t.Fatalf("error decoding output: %s", err)
}
if docs2, err = decodeToMapSlice(tc.want); err != nil {
t.Fatalf("Error decoding want: %s", err)
}
if !reflect.DeepEqual(docs1, docs2) {
t.Fatalf("incorrect output: got %s, want %s", string(output), tc.want)
}
})
}
}
func decodeToMapSlice(data string) ([]yaml.MapSlice, error) {
// ensures that the order is correct
var docs []yaml.MapSlice
decoder := yaml.NewDecoder(strings.NewReader(data))
for {
var doc yaml.MapSlice
err := decoder.Decode(&doc)
if err == io.EOF {
break
}
if err != nil {
return nil, err
}
docs = append(docs, doc)
}
return docs, nil
}
func TestParseToolFile(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
@@ -515,7 +810,7 @@ func TestParseToolFile(t *testing.T) {
wantToolsFile ToolsFile
}{
{
description: "basic example",
description: "basic example tools file v1",
in: `
sources:
my-pg-instance:
@@ -545,7 +840,7 @@ func TestParseToolFile(t *testing.T) {
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Kind: cloudsqlpgsrc.SourceKind,
Type: cloudsqlpgsrc.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -558,7 +853,7 @@ func TestParseToolFile(t *testing.T) {
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Kind: "postgres-sql",
Type: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
@@ -578,7 +873,121 @@ func TestParseToolFile(t *testing.T) {
},
},
{
description: "with prompts example",
description: "basic example tools file v2",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: authServices
name: my-google-auth
type: google
clientId: testing-id
---
kind: embeddingModels
name: gemini-model
type: gemini
model: gemini-embedding-001
apiKey: some-key
dimension: 768
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool
---
kind: prompts
name: code_review
description: ask llm to analyze code quality
messages:
- content: "please review the following code for quality: {{.code}}"
arguments:
- name: code
description: the code to review
`,
wantToolsFile: ToolsFile{
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Type: cloudsqlpgsrc.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
IPType: "public",
Database: "my_db",
User: "my_user",
Password: "my_pass",
},
},
AuthServices: server.AuthServiceConfigs{
"my-google-auth": google.Config{
Name: "my-google-auth",
Type: google.AuthServiceType,
ClientID: "testing-id",
},
},
EmbeddingModels: server.EmbeddingModelConfigs{
"gemini-model": gemini.Config{
Name: "gemini-model",
Type: gemini.EmbeddingModelType,
Model: "gemini-embedding-001",
ApiKey: "some-key",
Dimension: 768,
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Type: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Parameters: []parameters.Parameter{
parameters.NewStringParameter("country", "some description"),
},
AuthRequired: []string{},
},
},
Toolsets: server.ToolsetConfigs{
"example_toolset": tools.ToolsetConfig{
Name: "example_toolset",
ToolNames: []string{"example_tool"},
},
},
Prompts: server.PromptConfigs{
"code_review": custom.Config{
Name: "code_review",
Description: "ask llm to analyze code quality",
Arguments: prompts.Arguments{
{Parameter: parameters.NewStringParameter("code", "the code to review")},
},
Messages: []prompts.Message{
{Role: "user", Content: "please review the following code for quality: {{.code}}"},
},
},
},
},
},
{
description: "only prompts",
in: `
prompts:
my-prompt:
@@ -699,7 +1108,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Kind: cloudsqlpgsrc.SourceKind,
Type: cloudsqlpgsrc.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -712,19 +1121,19 @@ func TestParseToolFileWithAuth(t *testing.T) {
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Kind: google.AuthServiceKind,
Type: google.AuthServiceType,
ClientID: "my-client-id",
},
"other-google-service": google.Config{
Name: "other-google-service",
Kind: google.AuthServiceKind,
Type: google.AuthServiceType,
ClientID: "other-client-id",
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Kind: "postgres-sql",
Type: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
@@ -799,7 +1208,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Kind: cloudsqlpgsrc.SourceKind,
Type: cloudsqlpgsrc.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -809,22 +1218,22 @@ func TestParseToolFileWithAuth(t *testing.T) {
Password: "my_pass",
},
},
AuthSources: server.AuthServiceConfigs{
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Kind: google.AuthServiceKind,
Type: google.AuthServiceType,
ClientID: "my-client-id",
},
"other-google-service": google.Config{
Name: "other-google-service",
Kind: google.AuthServiceKind,
Type: google.AuthServiceType,
ClientID: "other-client-id",
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Kind: "postgres-sql",
Type: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
@@ -901,7 +1310,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Kind: cloudsqlpgsrc.SourceKind,
Type: cloudsqlpgsrc.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -914,19 +1323,19 @@ func TestParseToolFileWithAuth(t *testing.T) {
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Kind: google.AuthServiceKind,
Type: google.AuthServiceType,
ClientID: "my-client-id",
},
"other-google-service": google.Config{
Name: "other-google-service",
Kind: google.AuthServiceKind,
Type: google.AuthServiceType,
ClientID: "other-client-id",
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Kind: "postgres-sql",
Type: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
@@ -1072,7 +1481,7 @@ func TestEnvVarReplacement(t *testing.T) {
Sources: server.SourceConfigs{
"my-http-instance": httpsrc.Config{
Name: "my-http-instance",
Kind: httpsrc.SourceKind,
Type: httpsrc.SourceType,
BaseURL: "http://test_server/",
Timeout: "10s",
DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"},
@@ -1082,19 +1491,19 @@ func TestEnvVarReplacement(t *testing.T) {
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Kind: google.AuthServiceKind,
Type: google.AuthServiceType,
ClientID: "ACTUAL_CLIENT_ID",
},
"other-google-service": google.Config{
Name: "other-google-service",
Kind: google.AuthServiceKind,
Type: google.AuthServiceType,
ClientID: "ACTUAL_CLIENT_ID_2",
},
},
Tools: server.ToolConfigs{
"example_tool": http.Config{
Name: "example_tool",
Kind: "http",
Type: "http",
Source: "my-instance",
Method: "GET",
Path: "search?name=alice&pet=cat",
@@ -1503,7 +1912,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_postgres_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_postgres_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup", "restore_backup"},
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup"},
},
},
},
@@ -1513,7 +1922,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_mysql_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_mysql_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"},
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
},
},
},
@@ -1523,7 +1932,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_mssql_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_mssql_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"},
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
},
},
},

0
cmd/test.db Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@
"author": "",
"license": "ISC",
"dependencies": {
"@google/adk": "^0.2.4",
"@google/adk": "^0.1.3",
"@toolbox-sdk/adk": "^0.1.5"
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@
"author": "",
"license": "ISC",
"dependencies": {
"@llamaindex/google": "^0.4.0",
"@llamaindex/google": "^0.3.20",
"@llamaindex/workflow": "^1.1.22",
"@toolbox-sdk/core": "^0.1.2",
"llamaindex": "^0.12.0"

View File

@@ -54,7 +54,6 @@ instance, database and users:
* `create_instance`
* `create_user`
* `clone_instance`
* `restore_backup`
## Install MCP Toolbox
@@ -302,7 +301,6 @@ instances and interacting with your database:
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for SQL Server instance.
* **create_backup**: Creates a backup on a Cloud SQL instance.
* **restore_backup**: Restores a backup of a Cloud SQL instance.
{{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -54,7 +54,6 @@ database and users:
* `create_instance`
* `create_user`
* `clone_instance`
* `restore_backup`
## Install MCP Toolbox
@@ -302,7 +301,6 @@ instances and interacting with your database:
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for MySQL instance.
* **create_backup**: Creates a backup on a Cloud SQL instance.
* **restore_backup**: Restores a backup of a Cloud SQL instance.
{{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -54,7 +54,6 @@ instance, database and users:
* `create_instance`
* `create_user`
* `clone_instance`
* `restore_backup`
## Install MCP Toolbox
@@ -302,7 +301,6 @@ instances and interacting with your database:
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for PostgreSQL instance.
* **create_backup**: Creates a backup on a Cloud SQL instance.
* **restore_backup**: Restores a backup of a Cloud SQL instance.
{{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -20,7 +20,6 @@ The native SDKs can be combined with MCP clients in many cases.
Toolbox currently supports the following versions of MCP specification:
* [2025-11-25](https://modelcontextprotocol.io/specification/2025-11-25)
* [2025-06-18](https://modelcontextprotocol.io/specification/2025-06-18)
* [2025-03-26](https://modelcontextprotocol.io/specification/2025-03-26)
* [2024-11-05](https://modelcontextprotocol.io/specification/2024-11-05)

View File

@@ -207,7 +207,6 @@ You can connect to Toolbox Cloud Run instances directly through the SDK.
{{< tab header="Python" lang="python" >}}
import asyncio
from toolbox_core import ToolboxClient, auth_methods
from toolbox_core.protocol import Protocol
# Replace with the Cloud Run service URL generated in the previous step
URL = "https://cloud-run-url.app"
@@ -218,7 +217,6 @@ async def main():
async with ToolboxClient(
URL,
client_headers={"Authorization": auth_token_provider},
protocol=Protocol.TOOLBOX,
) as toolbox:
toolset = await toolbox.load_toolset()
# ...
@@ -283,5 +281,3 @@ contain the specific error message needed to diagnose the problem.
Manager, it means the Toolbox service account is missing permissions.
- Ensure the `toolbox-identity` service account has the **Secret Manager
Secret Accessor** (`roles/secretmanager.secretAccessor`) IAM role.
- **Cloud Run Connections via IAP:** Currently we do not support Cloud Run connections via [IAP](https://docs.cloud.google.com/iap/docs/concepts-overview). Please disable IAP if you are using it.

View File

@@ -27,7 +27,6 @@ description: >
| | `--ui` | Launches the Toolbox UI web server. | |
| | `--allowed-origins` | Specifies a list of origins permitted to access this server for CORs access. | `*` |
| | `--allowed-hosts` | Specifies a list of hosts permitted to access this server to prevent DNS rebinding attacks. | `*` |
| | `--user-agent-extra` | Appends additional metadata to the User-Agent. | |
| `-v` | `--version` | version for toolbox | |
## Examples

View File

@@ -194,7 +194,6 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_instance`
* `create_user`
* `clone_instance`
* `restore_backup`
* **Tools:**
* `create_instance`: Creates a new Cloud SQL for MySQL instance.
@@ -206,7 +205,6 @@ See [Usage Examples](../reference/cli.md#examples).
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for MySQL instance.
* `create_backup`: Creates a backup on a Cloud SQL instance.
* `restore_backup`: Restores a backup of a Cloud SQL instance.
## Cloud SQL for PostgreSQL
@@ -286,7 +284,6 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_instance`
* `create_user`
* `clone_instance`
* `restore_backup`
* **Tools:**
* `create_instance`: Creates a new Cloud SQL for PostgreSQL instance.
* `get_instance`: Gets information about a Cloud SQL instance.
@@ -297,7 +294,6 @@ See [Usage Examples](../reference/cli.md#examples).
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for PostgreSQL instance.
* `create_backup`: Creates a backup on a Cloud SQL instance.
* `restore_backup`: Restores a backup of a Cloud SQL instance.
## Cloud SQL for SQL Server
@@ -351,7 +347,6 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_instance`
* `create_user`
* `clone_instance`
* `restore_backup`
* **Tools:**
* `create_instance`: Creates a new Cloud SQL for SQL Server instance.
* `get_instance`: Gets information about a Cloud SQL instance.
@@ -362,7 +357,6 @@ See [Usage Examples](../reference/cli.md#examples).
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for SQL Server instance.
* `create_backup`: Creates a backup on a Cloud SQL instance.
* `restore_backup`: Restores a backup of a Cloud SQL instance.
## Dataplex

View File

@@ -12,9 +12,6 @@ aliases:
The `cloud-gemini-data-analytics-query` tool allows you to send natural language questions to the Gemini Data Analytics API and receive structured responses containing SQL queries, natural language answers, and explanations. For details on defining data agent context for database data sources, see the official [documentation](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases).
> [!NOTE]
> Only `alloydb`, `spannerReference`, and `cloudSqlReference` are supported as [datasource references](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1beta/projects.locations.dataAgents#DatasourceReferences).
## Example
```yaml
@@ -44,13 +41,13 @@ tools:
### Usage Flow
When using this tool, a `query` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
When using this tool, a `prompt` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
The structure of the response depends on the `generationOptions` configured in your tool definition (e.g., enabling `generateQueryResult` will include the SQL query results).
See [Data Analytics API REST documentation](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1alpha/projects.locations/queryData?rep_location=global) for details.
**Example Input Query:**
**Example Input Prompt:**
```text
How many accounts who have region in Prague are eligible for loans? A3 contains the data of region.

View File

@@ -1,53 +0,0 @@
---
title: cloud-sql-restore-backup
type: docs
weight: 10
description: "Restores a backup of a Cloud SQL instance."
---
The `cloud-sql-restore-backup` tool restores a backup on a Cloud SQL instance using the Cloud SQL Admin API.
{{< notice info dd>}}
This tool uses a `source` of kind `cloud-sql-admin`.
{{< /notice >}}
## Examples
Basic backup restore
```yaml
tools:
backup-restore-basic:
kind: cloud-sql-restore-backup
source: cloud-sql-admin-source
description: "Restores a backup onto the given Cloud SQL instance."
```
## Reference
### Tool Configuration
| **field** | **type** | **required** | **description** |
| -------------- | :------: | :----------: | ------------------------------------------------ |
| kind | string | true | Must be "cloud-sql-restore-backup". |
| source | string | true | The name of the `cloud-sql-admin` source to use. |
| description | string | false | A description of the tool. |
### Tool Inputs
| **parameter** | **type** | **required** | **description** |
| ------------------| :------: | :----------: | -----------------------------------------------------------------------------|
| target_project | string | true | The project ID of the instance to restore the backup onto. |
| target_instance | string | true | The instance to restore the backup onto. Does not include the project ID. |
| backup_id | string | true | The identifier of the backup being restored. |
| source_project | string | false | (Optional) The project ID of the instance that the backup belongs to. |
| source_instance | string | false | (Optional) Cloud SQL instance ID of the instance that the backup belongs to. |
## Usage Notes
- The `backup_id` field can be a BackupRun ID (which will be an int64), backup name, or BackupDR backup name.
- If the `backup_id` field contains a BackupRun ID (i.e. an int64), the optional fields `source_project` and `source_instance` must also be provided.
## See Also
- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api)
- [Toolbox Cloud SQL tools documentation](../cloudsql)
- [Cloud SQL Restore API documentation](https://cloud.google.com/sql/docs/mysql/backup-recovery/restoring)

View File

@@ -30,10 +30,6 @@ following config for example:
- name: userNames
type: array
description: The user names to be set.
items:
name: userName # the item name doesn't matter but it has to exist
type: string
description: username
```
If the input is an array of strings `["Alice", "Sid", "Bob"]`, The final command

View File

@@ -21,13 +21,13 @@ import (
// AuthServiceConfig is the interface for configuring authentication services.
type AuthServiceConfig interface {
AuthServiceConfigKind() string
AuthServiceConfigType() string
Initialize() (AuthService, error)
}
// AuthService is the interface for authentication services.
type AuthService interface {
AuthServiceKind() string
AuthServiceType() string
GetName() string
GetClaimsFromHeader(context.Context, http.Header) (map[string]any, error)
ToConfig() AuthServiceConfig

View File

@@ -23,7 +23,7 @@ import (
"google.golang.org/api/idtoken"
)
const AuthServiceKind string = "google"
const AuthServiceType string = "google"
// validate interface
var _ auth.AuthServiceConfig = Config{}
@@ -31,13 +31,13 @@ var _ auth.AuthServiceConfig = Config{}
// Auth service configuration
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
ClientID string `yaml:"clientId" validate:"required"`
}
// Returns the auth service kind
func (cfg Config) AuthServiceConfigKind() string {
return AuthServiceKind
// Returns the auth service type
func (cfg Config) AuthServiceConfigType() string {
return AuthServiceType
}
// Initialize a Google auth service
@@ -55,9 +55,9 @@ type AuthService struct {
Config
}
// Returns the auth service kind
func (a AuthService) AuthServiceKind() string {
return AuthServiceKind
// Returns the auth service type
func (a AuthService) AuthServiceType() string {
return AuthServiceType
}
func (a AuthService) ToConfig() auth.AuthServiceConfig {

View File

@@ -22,12 +22,12 @@ import (
// EmbeddingModelConfig is the interface for configuring embedding models.
type EmbeddingModelConfig interface {
EmbeddingModelConfigKind() string
EmbeddingModelConfigType() string
Initialize(context.Context) (EmbeddingModel, error)
}
type EmbeddingModel interface {
EmbeddingModelKind() string
EmbeddingModelType() string
ToConfig() EmbeddingModelConfig
EmbedParameters(context.Context, []string) ([][]float32, error)
}

View File

@@ -23,22 +23,22 @@ import (
"google.golang.org/genai"
)
const EmbeddingModelKind string = "gemini"
const EmbeddingModelType string = "gemini"
// validate interface
var _ embeddingmodels.EmbeddingModelConfig = Config{}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Model string `yaml:"model" validate:"required"`
ApiKey string `yaml:"apiKey"`
Dimension int32 `yaml:"dimension"`
}
// Returns the embedding model kind
func (cfg Config) EmbeddingModelConfigKind() string {
return EmbeddingModelKind
// Returns the embedding model type
func (cfg Config) EmbeddingModelConfigType() string {
return EmbeddingModelType
}
// Initialize a Gemini embedding model
@@ -69,9 +69,9 @@ type EmbeddingModel struct {
Config
}
// Returns the embedding model kind
func (m EmbeddingModel) EmbeddingModelKind() string {
return EmbeddingModelKind
// Returns the embedding model type
func (m EmbeddingModel) EmbeddingModelType() string {
return EmbeddingModelType
}
func (m EmbeddingModel) ToConfig() embeddingmodels.EmbeddingModelConfig {

View File

@@ -15,9 +15,9 @@
package gemini_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
@@ -34,15 +34,15 @@ func TestParseFromYamlGemini(t *testing.T) {
{
desc: "basic example",
in: `
embeddingModels:
my-gemini-model:
kind: gemini
model: text-embedding-004
kind: embeddingModels
name: my-gemini-model
type: gemini
model: text-embedding-004
`,
want: map[string]embeddingmodels.EmbeddingModelConfig{
"my-gemini-model": gemini.Config{
Name: "my-gemini-model",
Kind: gemini.EmbeddingModelKind,
Type: gemini.EmbeddingModelType,
Model: "text-embedding-004",
},
},
@@ -50,17 +50,17 @@ func TestParseFromYamlGemini(t *testing.T) {
{
desc: "full example with optional fields",
in: `
embeddingModels:
complex-gemini:
kind: gemini
model: text-embedding-004
apiKey: "test-api-key"
dimension: 768
kind: embeddingModels
name: complex-gemini
type: gemini
model: text-embedding-004
apiKey: "test-api-key"
dimension: 768
`,
want: map[string]embeddingmodels.EmbeddingModelConfig{
"complex-gemini": gemini.Config{
Name: "complex-gemini",
Kind: gemini.EmbeddingModelKind,
Type: gemini.EmbeddingModelType,
Model: "text-embedding-004",
ApiKey: "test-api-key",
Dimension: 768,
@@ -70,16 +70,13 @@ func TestParseFromYamlGemini(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, got, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Models) {
t.Fatalf("incorrect parse: %v", cmp.Diff(tc.want, got.Models))
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: %v", cmp.Diff(tc.want, got))
}
})
}
@@ -93,32 +90,29 @@ func TestFailParseFromYamlGemini(t *testing.T) {
{
desc: "missing required model field",
in: `
embeddingModels:
bad-model:
kind: gemini
kind: embeddingModels
name: bad-model
type: gemini
`,
// Removed the specific model name from the prefix to match your output
err: "unable to parse as \"gemini\": Key: 'Config.Model' Error:Field validation for 'Model' failed on the 'required' tag",
err: "error unmarshaling embeddingModels: unable to parse as \"bad-model\": Key: 'Config.Model' Error:Field validation for 'Model' failed on the 'required' tag",
},
{
desc: "unknown field",
in: `
embeddingModels:
bad-field:
kind: gemini
model: text-embedding-004
invalid_param: true
kind: embeddingModels
name: bad-field
type: gemini
model: text-embedding-004
invalid_param: true
`,
// Updated to match the specific line-starting format of your error output
err: "unable to parse as \"gemini\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | kind: gemini\n 3 | model: text-embedding-004",
err: "error unmarshaling embeddingModels: unable to parse as \"bad-field\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | model: text-embedding-004\n 3 | name: bad-field\n 4 | type: gemini",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
}{}
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -46,9 +46,6 @@ tools:
create_backup:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
restore_backup:
kind: cloud-sql-restore-backup
source: cloud-sql-admin-source
toolsets:
cloud_sql_mssql_admin_tools:
@@ -61,4 +58,3 @@ toolsets:
- wait_for_operation
- clone_instance
- create_backup
- restore_backup

View File

@@ -46,9 +46,6 @@ tools:
create_backup:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
restore_backup:
kind: cloud-sql-restore-backup
source: cloud-sql-admin-source
toolsets:
cloud_sql_mysql_admin_tools:
@@ -61,4 +58,3 @@ toolsets:
- wait_for_operation
- clone_instance
- create_backup
- restore_backup

View File

@@ -49,9 +49,6 @@ tools:
create_backup:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
restore_backup:
kind: cloud-sql-restore-backup
source: cloud-sql-admin-source
toolsets:
cloud_sql_postgres_admin_tools:
@@ -65,4 +62,3 @@ toolsets:
- postgres_upgrade_precheck
- clone_instance
- create_backup
- restore_backup

View File

@@ -25,12 +25,12 @@ import (
type Message = prompts.Message
const kind = "custom"
const resourceType = "custom"
// init registers this prompt kind with the prompt framework.
// init registers this prompt type with the prompt framework.
func init() {
if !prompts.Register(kind, newConfig) {
panic(fmt.Sprintf("prompt kind %q already registered", kind))
if !prompts.Register(resourceType, newConfig) {
panic(fmt.Sprintf("prompt type %q already registered", resourceType))
}
}
@@ -56,8 +56,8 @@ type Config struct {
var _ prompts.PromptConfig = Config{}
var _ prompts.Prompt = Prompt{}
func (c Config) PromptConfigKind() string {
return kind
func (c Config) PromptConfigType() string {
return resourceType
}
func (c Config) Initialize() (prompts.Prompt, error) {

View File

@@ -42,7 +42,7 @@ func TestConfig(t *testing.T) {
Arguments: testArgs,
}
// initialize and check kind
// initialize and check type
p, err := cfg.Initialize()
if err != nil {
t.Fatalf("Initialize() failed: %v", err)
@@ -50,8 +50,8 @@ func TestConfig(t *testing.T) {
if p == nil {
t.Fatal("Initialize() returned a nil prompt")
}
if cfg.PromptConfigKind() != "custom" {
t.Errorf("PromptConfigKind() = %q, want %q", cfg.PromptConfigKind(), "custom")
if cfg.PromptConfigType() != "custom" {
t.Errorf("PromptConfigType() = %q, want %q", cfg.PromptConfigType(), "custom")
}
t.Run("Manifest", func(t *testing.T) {

View File

@@ -30,40 +30,40 @@ var promptRegistry = make(map[string]PromptConfigFactory)
// Register allows individual prompt packages to register their configuration
// factory function. This is typically called from an init() function in the
// prompt's package. It associates a 'kind' string with a function that can
// prompt's package. It associates a 'type' string with a function that can
// produce the specific PromptConfig type. It returns true if the registration was
// successful, and false if a prompt with the same kind was already registered.
func Register(kind string, factory PromptConfigFactory) bool {
if _, exists := promptRegistry[kind]; exists {
// Prompt with this kind already exists, do not overwrite.
// successful, and false if a prompt with the same type was already registered.
func Register(resourceType string, factory PromptConfigFactory) bool {
if _, exists := promptRegistry[resourceType]; exists {
// Prompt with this type already exists, do not overwrite.
return false
}
promptRegistry[kind] = factory
promptRegistry[resourceType] = factory
return true
}
// DecodeConfig looks up the registered factory for the given kind and uses it
// DecodeConfig looks up the registered factory for the given type and uses it
// to decode the prompt configuration.
func DecodeConfig(ctx context.Context, kind, name string, decoder *yaml.Decoder) (PromptConfig, error) {
factory, found := promptRegistry[kind]
if !found && kind == "" {
kind = "custom"
factory, found = promptRegistry[kind]
func DecodeConfig(ctx context.Context, resourceType, name string, decoder *yaml.Decoder) (PromptConfig, error) {
factory, found := promptRegistry[resourceType]
if !found && resourceType == "" {
resourceType = "custom"
factory, found = promptRegistry[resourceType]
}
if !found {
return nil, fmt.Errorf("unknown prompt kind: %q", kind)
return nil, fmt.Errorf("unknown prompt type: %q", resourceType)
}
promptConfig, err := factory(ctx, name, decoder)
if err != nil {
return nil, fmt.Errorf("unable to parse prompt %q as kind %q: %w", name, kind, err)
return nil, fmt.Errorf("unable to parse prompt %q as resourceType %q: %w", name, resourceType, err)
}
return promptConfig, nil
}
type PromptConfig interface {
PromptConfigKind() string
PromptConfigType() string
Initialize() (Prompt, error)
}

View File

@@ -29,16 +29,16 @@ import (
type mockPromptConfig struct {
name string
kind string
Type string
}
func (m *mockPromptConfig) PromptConfigKind() string { return m.kind }
func (m *mockPromptConfig) PromptConfigType() string { return m.Type }
func (m *mockPromptConfig) Initialize() (prompts.Prompt, error) { return nil, nil }
var errMockFactory = errors.New("mock factory error")
func mockFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
return &mockPromptConfig{name: name, kind: "mockKind"}, nil
return &mockPromptConfig{name: name, Type: "mockType"}, nil
}
func mockErrorFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
@@ -50,17 +50,17 @@ func TestRegistry(t *testing.T) {
ctx := context.Background()
t.Run("RegisterAndDecodeSuccess", func(t *testing.T) {
kind := "testKindSuccess"
if !prompts.Register(kind, mockFactory) {
resourceType := "testTypeSuccess"
if !prompts.Register(resourceType, mockFactory) {
t.Fatal("expected registration to succeed")
}
// This should fail because we are registering a duplicate
if prompts.Register(kind, mockFactory) {
if prompts.Register(resourceType, mockFactory) {
t.Fatal("expected duplicate registration to fail")
}
decoder := yaml.NewDecoder(strings.NewReader(""))
config, err := prompts.DecodeConfig(ctx, kind, "testPrompt", decoder)
config, err := prompts.DecodeConfig(ctx, resourceType, "testPrompt", decoder)
if err != nil {
t.Fatalf("expected DecodeConfig to succeed, but got error: %v", err)
}
@@ -69,25 +69,25 @@ func TestRegistry(t *testing.T) {
}
})
t.Run("DecodeUnknownKind", func(t *testing.T) {
t.Run("DecodeUnknownType", func(t *testing.T) {
decoder := yaml.NewDecoder(strings.NewReader(""))
_, err := prompts.DecodeConfig(ctx, "unregisteredKind", "testPrompt", decoder)
_, err := prompts.DecodeConfig(ctx, "unregisteredType", "testPrompt", decoder)
if err == nil {
t.Fatal("expected an error for unknown kind, but got nil")
t.Fatal("expected an error for unknown type, but got nil")
}
if !strings.Contains(err.Error(), "unknown prompt kind") {
t.Errorf("expected error to contain 'unknown prompt kind', but got: %v", err)
if !strings.Contains(err.Error(), "unknown prompt type") {
t.Errorf("expected error to contain 'unknown prompt type', but got: %v", err)
}
})
t.Run("FactoryReturnsError", func(t *testing.T) {
kind := "testKindError"
if !prompts.Register(kind, mockErrorFactory) {
resourceType := "testTypeError"
if !prompts.Register(resourceType, mockErrorFactory) {
t.Fatal("expected registration to succeed")
}
decoder := yaml.NewDecoder(strings.NewReader(""))
_, err := prompts.DecodeConfig(ctx, kind, "testPrompt", decoder)
_, err := prompts.DecodeConfig(ctx, resourceType, "testPrompt", decoder)
if err == nil {
t.Fatal("expected an error from the factory, but got nil")
}
@@ -100,13 +100,13 @@ func TestRegistry(t *testing.T) {
decoder := yaml.NewDecoder(strings.NewReader("description: A test prompt"))
config, err := prompts.DecodeConfig(ctx, "", "testDefaultPrompt", decoder)
if err != nil {
t.Fatalf("expected DecodeConfig with empty kind to succeed, but got error: %v", err)
t.Fatalf("expected DecodeConfig with empty type to succeed, but got error: %v", err)
}
if config == nil {
t.Fatal("expected a non-nil config for default kind")
t.Fatal("expected a non-nil config for default type")
}
if config.PromptConfigKind() != "custom" {
t.Errorf("expected default kind to be 'custom', but got %q", config.PromptConfigKind())
if config.PromptConfigType() != "custom" {
t.Errorf("expected default type to be 'custom', but got %q", config.PromptConfigType())
}
})
}

View File

@@ -14,8 +14,10 @@
package server
import (
"bytes"
"context"
"fmt"
"io"
"strings"
yaml "github.com/goccy/go-yaml"
@@ -64,14 +66,12 @@ type ServerConfig struct {
Stdio bool
// DisableReload indicates if the user has disabled dynamic reloading for Toolbox.
DisableReload bool
// UI indicates if Toolbox UI endpoints (/ui) are available.
// UI indicates if Toolbox UI endpoints (/ui) are available
UI bool
// Specifies a list of origins permitted to access this server.
AllowedOrigins []string
// Specifies a list of hosts permitted to access this server.
// Specifies a list of hosts permitted to access this server
AllowedHosts []string
// UserAgentMetadata specifies additional metadata to append to the User-Agent string.
UserAgentMetadata []string
}
type logFormat string
@@ -126,272 +126,201 @@ func (s *StringLevel) Type() string {
return "stringLevel"
}
// SourceConfigs is a type used to allow unmarshal of the data source config map
type SourceConfigs map[string]sources.SourceConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &SourceConfigs{}
func (c *SourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(SourceConfigs)
// Parse the 'kind' fields for each source
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
// Unmarshal to a general type that ensure it capture all fields
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
}
kind, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for source %q", name)
}
kindStr, ok := kind.(string)
if !ok {
return fmt.Errorf("invalid 'kind' field for source %q (must be a string)", name)
}
yamlDecoder, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating YAML decoder for source %q: %w", name, err)
}
sourceConfig, err := sources.DecodeConfig(ctx, kindStr, name, yamlDecoder)
if err != nil {
return err
}
(*c)[name] = sourceConfig
}
return nil
}
// AuthServiceConfigs is a type used to allow unmarshal of the data authService config map
type AuthServiceConfigs map[string]auth.AuthServiceConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &AuthServiceConfigs{}
func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(AuthServiceConfigs)
// Parse the 'kind' fields for each authService
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
}
kind, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for %q", name)
}
dec, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating decoder: %w", err)
}
switch kind {
case google.AuthServiceKind:
actual := google.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of auth source", kind)
}
}
return nil
}
// EmbeddingModelConfigs is a type used to allow unmarshal of the embedding model config map
type EmbeddingModelConfigs map[string]embeddingmodels.EmbeddingModelConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &EmbeddingModelConfigs{}
func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(EmbeddingModelConfigs)
// Parse the 'kind' fields for each embedding model
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
// Unmarshal to a general type that ensure it capture all fields
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal embedding model %q: %w", name, err)
}
kind, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for embedding model %q", name)
}
dec, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating decoder: %w", err)
}
switch kind {
case gemini.EmbeddingModelKind:
actual := gemini.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of auth source", kind)
}
}
return nil
}
// ToolConfigs is a type used to allow unmarshal of the tool configs
type ToolConfigs map[string]tools.ToolConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &ToolConfigs{}
func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(ToolConfigs)
// Parse the 'kind' fields for each source
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
}
// `authRequired` and `useClientOAuth` cannot be specified together
if v["authRequired"] != nil && v["useClientOAuth"] == true {
return fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method")
}
// Make `authRequired` an empty list instead of nil for Tool manifest
if v["authRequired"] == nil {
v["authRequired"] = []string{}
}
kindVal, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for tool %q", name)
}
kindStr, ok := kindVal.(string)
if !ok {
return fmt.Errorf("invalid 'kind' field for tool %q (must be a string)", name)
}
yamlDecoder, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating YAML decoder for tool %q: %w", name, err)
}
toolCfg, err := tools.DecodeConfig(ctx, kindStr, name, yamlDecoder)
if err != nil {
return err
}
(*c)[name] = toolCfg
}
return nil
}
// ToolsetConfigs is a type used to allow unmarshal of the toolset configs
type ToolsetConfigs map[string]tools.ToolsetConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &ToolsetConfigs{}
func (c *ToolsetConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(ToolsetConfigs)
var raw map[string][]string
if err := unmarshal(&raw); err != nil {
return err
}
for name, toolList := range raw {
(*c)[name] = tools.ToolsetConfig{Name: name, ToolNames: toolList}
}
return nil
}
// PromptConfigs is a type used to allow unmarshal of the prompt configs
type PromptConfigs map[string]prompts.PromptConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &PromptConfigs{}
func (c *PromptConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(PromptConfigs)
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal prompt %q: %w", name, err)
}
// Look for the 'kind' field. If it's not present, kindStr will be an
// empty string, which prompts.DecodeConfig will correctly default to "custom".
var kindStr string
if kindVal, ok := v["kind"]; ok {
var isString bool
kindStr, isString = kindVal.(string)
if !isString {
return fmt.Errorf("invalid 'kind' field for prompt %q (must be a string)", name)
}
}
// Create a new, strict decoder for this specific prompt's data.
yamlDecoder, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating YAML decoder for prompt %q: %w", name, err)
}
// Use the central registry to decode the prompt based on its kind.
promptCfg, err := prompts.DecodeConfig(ctx, kindStr, name, yamlDecoder)
if err != nil {
return err
}
(*c)[name] = promptCfg
}
return nil
}
// PromptsetConfigs is a type used to allow unmarshal of the PromptsetConfigs configs
type PromptsetConfigs map[string]prompts.PromptsetConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &PromptsetConfigs{}
func UnmarshalResourceConfig(ctx context.Context, raw []byte) (SourceConfigs, AuthServiceConfigs, EmbeddingModelConfigs, ToolConfigs, ToolsetConfigs, PromptConfigs, error) {
// prepare configs map
sourceConfigs := make(map[string]sources.SourceConfig)
authServiceConfigs := make(AuthServiceConfigs)
embeddingModelConfigs := make(EmbeddingModelConfigs)
toolConfigs := make(ToolConfigs)
toolsetConfigs := make(ToolsetConfigs)
promptConfigs := make(PromptConfigs)
// promptset configs is not yet supported
func (c *PromptsetConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(PromptsetConfigs)
decoder := yaml.NewDecoder(bytes.NewReader(raw))
// for loop to unmarshal documents with the `---` separator
for {
var resource map[string]any
if err := decoder.DecodeContext(ctx, &resource); err != nil {
if err == io.EOF {
break
}
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unable to decode YAML document: %w", err)
}
var kind, name string
var ok bool
if kind, ok = resource["kind"].(string); !ok {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("missing 'kind' field or it is not a string")
}
if name, ok = resource["name"].(string); !ok {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("missing 'name' field or it is not a string")
}
// remove 'kind' from map for strict unmarshaling
delete(resource, "kind")
var raw map[string][]string
if err := unmarshal(&raw); err != nil {
return err
switch kind {
case "sources":
c, err := UnmarshalYAMLSourceConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
sourceConfigs[name] = c
case "authServices":
c, err := UnmarshalYAMLAuthServiceConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
authServiceConfigs[name] = c
case "tools":
c, err := UnmarshalYAMLToolConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
toolConfigs[name] = c
case "toolsets":
c, err := UnmarshalYAMLToolsetConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
toolsetConfigs[name] = c
case "embeddingModels":
c, err := UnmarshalYAMLEmbeddingModelConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
embeddingModelConfigs[name] = c
case "prompts":
c, err := UnmarshalYAMLPromptConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
promptConfigs[name] = c
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("invalid kind %s", kind)
}
}
for name, promptList := range raw {
(*c)[name] = prompts.PromptsetConfig{Name: name, PromptNames: promptList}
}
return nil
return sourceConfigs, authServiceConfigs, embeddingModelConfigs, toolConfigs, toolsetConfigs, promptConfigs, nil
}
func UnmarshalYAMLSourceConfig(ctx context.Context, name string, r map[string]any) (sources.SourceConfig, error) {
resourceType, ok := r["type"].(string)
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %w", err)
}
sourceConfig, err := sources.DecodeConfig(ctx, resourceType, name, dec)
if err != nil {
return nil, err
}
return sourceConfig, nil
}
func UnmarshalYAMLAuthServiceConfig(ctx context.Context, name string, r map[string]any) (auth.AuthServiceConfig, error) {
resourceType, ok := r["type"].(string)
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}
if resourceType != google.AuthServiceType {
return nil, fmt.Errorf("%s is not a valid type of auth service", resourceType)
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %s", err)
}
actual := google.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return nil, fmt.Errorf("unable to parse as %s: %w", name, err)
}
return actual, nil
}
func UnmarshalYAMLEmbeddingModelConfig(ctx context.Context, name string, r map[string]any) (embeddingmodels.EmbeddingModelConfig, error) {
resourceType, ok := r["type"].(string)
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}
if resourceType != gemini.EmbeddingModelType {
return nil, fmt.Errorf("%s is not a valid type of embedding model", resourceType)
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %s", err)
}
actual := gemini.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", name, err)
}
return actual, nil
}
func UnmarshalYAMLToolConfig(ctx context.Context, name string, r map[string]any) (tools.ToolConfig, error) {
resourceType, ok := r["type"].(string)
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}
// `authRequired` and `useClientOAuth` cannot be specified together
if r["authRequired"] != nil && r["useClientOAuth"] == true {
return nil, fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method")
}
// Make `authRequired` an empty list instead of nil for Tool manifest
if r["authRequired"] == nil {
r["authRequired"] = []string{}
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %s", err)
}
toolCfg, err := tools.DecodeConfig(ctx, resourceType, name, dec)
if err != nil {
return nil, err
}
return toolCfg, nil
}
func UnmarshalYAMLToolsetConfig(ctx context.Context, name string, r map[string]any) (tools.ToolsetConfig, error) {
var toolsetConfig tools.ToolsetConfig
justTools := map[string]any{"tools": r["tools"]}
dec, err := util.NewStrictDecoder(justTools)
if err != nil {
return toolsetConfig, fmt.Errorf("error creating decoder: %s", err)
}
var raw map[string][]string
if err := dec.DecodeContext(ctx, &raw); err != nil {
return toolsetConfig, fmt.Errorf("unable to unmarshal tools: %s", err)
}
return tools.ToolsetConfig{Name: name, ToolNames: raw["tools"]}, nil
}
func UnmarshalYAMLPromptConfig(ctx context.Context, name string, r map[string]any) (prompts.PromptConfig, error) {
// Look for the 'type' field. If it's not present, typeStr will be an
// empty string, which prompts.DecodeConfig will correctly default to "custom".
var resourceType string
if typeVal, ok := r["type"]; ok {
var isString bool
resourceType, isString = typeVal.(string)
if !isString {
return nil, fmt.Errorf("invalid 'type' field for prompt %q (must be a string)", name)
}
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %s", err)
}
// Use the central registry to decode the prompt based on its type.
promptCfg, err := prompts.DecodeConfig(ctx, resourceType, name, dec)
if err != nil {
return nil, err
}
return promptCfg, nil
}

View File

@@ -27,21 +27,19 @@ import (
v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105"
v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326"
v20250618 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250618"
v20251125 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20251125"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/tools"
)
// LATEST_PROTOCOL_VERSION is the latest version of the MCP protocol supported.
// Update the version used in InitializeResponse when this value is updated.
const LATEST_PROTOCOL_VERSION = v20251125.PROTOCOL_VERSION
const LATEST_PROTOCOL_VERSION = v20250618.PROTOCOL_VERSION
// SUPPORTED_PROTOCOL_VERSIONS is the MCP protocol versions that are supported.
var SUPPORTED_PROTOCOL_VERSIONS = []string{
v20241105.PROTOCOL_VERSION,
v20250326.PROTOCOL_VERSION,
v20250618.PROTOCOL_VERSION,
v20251125.PROTOCOL_VERSION,
}
// InitializeResponse runs capability negotiation and protocol version agreement.
@@ -104,8 +102,6 @@ func NotificationHandler(ctx context.Context, body []byte) error {
// This is the Operation phase of the lifecycle for MCP client-server connections.
func ProcessMethod(ctx context.Context, mcpVersion string, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
switch mcpVersion {
case v20251125.PROTOCOL_VERSION:
return v20251125.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header)
case v20250618.PROTOCOL_VERSION:
return v20250618.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header)
case v20250326.PROTOCOL_VERSION:

View File

@@ -183,13 +183,6 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

View File

@@ -183,13 +183,6 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

View File

@@ -176,13 +176,6 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

View File

@@ -1,333 +0,0 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package v20251125
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
)
// ProcessMethod returns a response for the request.
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
switch method {
case PING:
return pingHandler(id)
case TOOLS_LIST:
return toolsListHandler(id, toolset, body)
case TOOLS_CALL:
return toolsCallHandler(ctx, id, resourceMgr, body, header)
case PROMPTS_LIST:
return promptsListHandler(ctx, id, promptset, body)
case PROMPTS_GET:
return promptsGetHandler(ctx, id, resourceMgr, body)
default:
err := fmt.Errorf("invalid method %s", method)
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
}
}
// pingHandler handles the "ping" method by returning an empty response.
func pingHandler(id jsonrpc.RequestId) (any, error) {
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: struct{}{},
}, nil
}
func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) (any, error) {
var req ListToolsRequest
if err := json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp tools list request: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
result := ListToolsResult{
Tools: toolset.McpManifest,
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: result,
}, nil
}
// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
authServices := resourceMgr.GetAuthServiceMap()
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
var req CallToolRequest
if err = json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp tools call request: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
toolName := req.Params.Name
toolArgument := req.Params.Arguments
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
tool, ok := resourceMgr.GetTool(toolName)
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// Get access token
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
}
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
aMarshal, err := json.Marshal(toolArgument)
if err != nil {
err = fmt.Errorf("unable to marshal tools argument: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
var data map[string]any
if err = util.DecodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
err = fmt.Errorf("unable to decode tools argument: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
// Tool authentication
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
claimsFromAuth := make(map[string]map[string]any)
// if using stdio, header will be nil and auth will not be supported
if header != nil {
for _, aS := range authServices {
claims, err := aS.GetClaimsFromHeader(ctx, header)
if err != nil {
logger.DebugContext(ctx, err.Error())
continue
}
if claims == nil {
// authService not present in header
continue
}
claimsFromAuth[aS.GetName()] = claims
}
}
// Tool authorization check
verifiedAuthServices := make([]string, len(claimsFromAuth))
i := 0
for k := range claimsFromAuth {
verifiedAuthServices[i] = k
i++
}
// Check if any of the specified auth services is verified
isAuthorized := tool.Authorized(verifiedAuthServices)
if !isAuthorized {
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
logger.DebugContext(ctx, "tool invocation authorized")
params, err := tool.ParseParams(data, claimsFromAuth)
if err != nil {
err = fmt.Errorf("provided parameters were invalid: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {
errStr := err.Error()
// Missing authService tokens.
if errors.Is(err, util.ErrUnauthorized) {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Auth error with ADC should raise internal 500 error
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
}
content := make([]TextContent, 0)
sliceRes, ok := results.([]any)
if !ok {
sliceRes = []any{results}
}
for _, d := range sliceRes {
text := TextContent{Type: "text"}
dM, err := json.Marshal(d)
if err != nil {
text.Text = fmt.Sprintf("fail to marshal: %s, result: %s", err, d)
} else {
text.Text = string(dM)
}
content = append(content, text)
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: content},
}, nil
}
// promptsListHandler handles the "prompts/list" method.
func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, body []byte) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
logger.DebugContext(ctx, "handling prompts/list request")
var req ListPromptsRequest
if err := json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp prompts list request: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
result := ListPromptsResult{
Prompts: promptset.McpManifest,
}
logger.DebugContext(ctx, fmt.Sprintf("returning %d prompts", len(promptset.McpManifest)))
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: result,
}, nil
}
// promptsGetHandler handles the "prompts/get" method.
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
logger.DebugContext(ctx, "handling prompts/get request")
var req GetPromptRequest
if err := json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp prompts/get request: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
promptName := req.Params.Name
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
prompt, ok := resourceMgr.GetPrompt(promptName)
if !ok {
err := fmt.Errorf("prompt with name %q does not exist", promptName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// Parse the arguments provided in the request.
argValues, err := prompt.ParseArgs(req.Params.Arguments, nil)
if err != nil {
err = fmt.Errorf("invalid arguments for prompt %q: %w", promptName, err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
logger.DebugContext(ctx, fmt.Sprintf("parsed args: %v", argValues))
// Substitute the argument values into the prompt's messages.
substituted, err := prompt.SubstituteParams(argValues)
if err != nil {
err = fmt.Errorf("error substituting params for prompt %q: %w", promptName, err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
// Cast the result to the expected []prompts.Message type.
substitutedMessages, ok := substituted.([]prompts.Message)
if !ok {
err = fmt.Errorf("internal error: SubstituteParams returned unexpected type")
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
logger.DebugContext(ctx, "substituted params successfully")
// Format the response messages into the required structure.
promptMessages := make([]PromptMessage, len(substitutedMessages))
for i, msg := range substitutedMessages {
promptMessages[i] = PromptMessage{
Role: msg.Role,
Content: TextContent{
Type: "text",
Text: msg.Content,
},
}
}
result := GetPromptResult{
Description: prompt.Manifest().Description,
Messages: promptMessages,
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: result,
}, nil
}

View File

@@ -1,219 +0,0 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package v20251125
import (
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/googleapis/genai-toolbox/internal/tools"
)
// SERVER_NAME is the server name used in Implementation.
const SERVER_NAME = "Toolbox"
// PROTOCOL_VERSION is the version of the MCP protocol in this package.
const PROTOCOL_VERSION = "2025-11-25"
// methods that are supported.
const (
PING = "ping"
TOOLS_LIST = "tools/list"
TOOLS_CALL = "tools/call"
PROMPTS_LIST = "prompts/list"
PROMPTS_GET = "prompts/get"
)
/* Empty result */
// EmptyResult represents a response that indicates success but carries no data.
type EmptyResult jsonrpc.Result
/* Pagination */
// Cursor is an opaque token used to represent a cursor for pagination.
type Cursor string
type PaginatedRequest struct {
jsonrpc.Request
Params struct {
// An opaque token representing the current pagination position.
// If provided, the server should return results starting after this cursor.
Cursor Cursor `json:"cursor,omitempty"`
} `json:"params,omitempty"`
}
type PaginatedResult struct {
jsonrpc.Result
// An opaque token representing the pagination position after the last returned result.
// If present, there may be more results available.
NextCursor Cursor `json:"nextCursor,omitempty"`
}
/* Tools */
// Sent from the client to request a list of tools the server has.
type ListToolsRequest struct {
PaginatedRequest
}
// The server's response to a tools/list request from the client.
type ListToolsResult struct {
PaginatedResult
Tools []tools.McpManifest `json:"tools"`
}
// Used by the client to invoke a tool provided by the server.
type CallToolRequest struct {
jsonrpc.Request
Params struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
} `json:"params,omitempty"`
}
// The sender or recipient of messages and data in a conversation.
type Role string
const (
RoleUser Role = "user"
RoleAssistant Role = "assistant"
)
// Base for objects that include optional annotations for the client.
// The client can use annotations to inform how objects are used or displayed
type Annotated struct {
Annotations *struct {
// Describes who the intended customer of this object or data is.
// It can include multiple entries to indicate content useful for multiple
// audiences (e.g., `["user", "assistant"]`).
Audience []Role `json:"audience,omitempty"`
// Describes how important this data is for operating the server.
//
// A value of 1 means "most important," and indicates that the data is
// effectively required, while 0 means "least important," and indicates that
// the data is entirely optional.
//
// @TJS-type number
// @minimum 0
// @maximum 1
Priority float64 `json:"priority,omitempty"`
} `json:"annotations,omitempty"`
}
// TextContent represents text provided to or from an LLM.
type TextContent struct {
Annotated
Type string `json:"type"`
// The text content of the message.
Text string `json:"text"`
}
// The server's response to a tool call.
//
// Any errors that originate from the tool SHOULD be reported inside the result
// object, with `isError` set to true, _not_ as an MCP protocol-level error
// response. Otherwise, the LLM would not be able to see that an error occurred
// and self-correct.
//
// However, any errors in _finding_ the tool, an error indicating that the
// server does not support tool calls, or any other exceptional conditions,
// should be reported as an MCP error response.
type CallToolResult struct {
jsonrpc.Result
// Could be either a TextContent, ImageContent, or EmbeddedResources
// For Toolbox, we will only be sending TextContent
Content []TextContent `json:"content"`
// Whether the tool call ended in an error.
// If not set, this is assumed to be false (the call was successful).
//
// Any errors that originate from the tool SHOULD be reported inside the result
// object, with `isError` set to true, _not_ as an MCP protocol-level error
// response. Otherwise, the LLM would not be able to see that an error occurred
// and self-correct.
//
// However, any errors in _finding_ the tool, an error indicating that the
// server does not support tool calls, or any other exceptional conditions,
// should be reported as an MCP error response.
IsError bool `json:"isError,omitempty"`
// An optional JSON object that represents the structured result of the tool call.
StructuredContent map[string]any `json:"structuredContent,omitempty"`
}
// Additional properties describing a Tool to clients.
//
// NOTE: all properties in ToolAnnotations are **hints**.
// They are not guaranteed to provide a faithful description of
// tool behavior (including descriptive properties like `title`).
//
// Clients should never make tool use decisions based on ToolAnnotations
// received from untrusted servers.
type ToolAnnotations struct {
// A human-readable title for the tool.
Title string `json:"title,omitempty"`
// If true, the tool does not modify its environment.
// Default: false
ReadOnlyHint bool `json:"readOnlyHint,omitempty"`
// If true, the tool may perform destructive updates to its environment.
// If false, the tool performs only additive updates.
// (This property is meaningful only when `readOnlyHint == false`)
// Default: true
DestructiveHint bool `json:"destructiveHint,omitempty"`
// If true, calling the tool repeatedly with the same arguments
// will have no additional effect on the its environment.
// (This property is meaningful only when `readOnlyHint == false`)
// Default: false
IdempotentHint bool `json:"idempotentHint,omitempty"`
// If true, this tool may interact with an "open world" of external
// entities. If false, the tool's domain of interaction is closed.
// For example, the world of a web search tool is open, whereas that
// of a memory tool is not.
// Default: true
OpenWorldHint bool `json:"openWorldHint,omitempty"`
}
/* Prompts */
// Sent from the client to request a list of prompts the server has.
type ListPromptsRequest struct {
PaginatedRequest
}
// The server's response to a prompts/list request from the client.
type ListPromptsResult struct {
PaginatedResult
Prompts []prompts.McpManifest `json:"prompts"`
}
// Used by the client to get a prompt provided by the server.
type GetPromptRequest struct {
jsonrpc.Request
Params struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
} `json:"params"`
}
// The server's response to a prompts/get request from the client.
type GetPromptResult struct {
jsonrpc.Result
Description string `json:"description,omitempty"`
Messages []PromptMessage `json:"messages"`
}
// Describes a message returned as part of a prompt.
type PromptMessage struct {
Role string `json:"role"`
Content TextContent `json:"content"`
}

View File

@@ -37,7 +37,6 @@ const jsonrpcVersion = "2.0"
const protocolVersion20241105 = "2024-11-05"
const protocolVersion20250326 = "2025-03-26"
const protocolVersion20250618 = "2025-06-18"
const protocolVersion20251125 = "2025-11-25"
const serverName = "Toolbox"
var basicInputSchema = map[string]any{
@@ -486,23 +485,6 @@ func TestMcpEndpoint(t *testing.T) {
},
},
},
{
name: "version 2025-11-25",
protocol: protocolVersion20251125,
idHeader: false,
initWant: map[string]any{
"jsonrpc": "2.0",
"id": "mcp-initialize",
"result": map[string]any{
"protocolVersion": "2025-11-25",
"capabilities": map[string]any{
"tools": map[string]any{"listChanged": false},
"prompts": map[string]any{"listChanged": false},
},
"serverInfo": map[string]any{"name": serverName, "version": fakeVersionString},
},
},
},
}
for _, vtc := range versTestCases {
t.Run(vtc.name, func(t *testing.T) {
@@ -512,7 +494,8 @@ func TestMcpEndpoint(t *testing.T) {
if sessionId != "" {
header["Mcp-Session-Id"] = sessionId
}
if vtc.protocol != protocolVersion20241105 && vtc.protocol != protocolVersion20250326 {
if vtc.protocol == protocolVersion20250618 {
header["MCP-Protocol-Version"] = vtc.protocol
}

View File

@@ -32,7 +32,7 @@ func TestUpdateServer(t *testing.T) {
"example-source": &alloydbpg.Source{
Config: alloydbpg.Config{
Name: "example-alloydb-source",
Kind: "alloydb-postgres",
Type: "alloydb-postgres",
},
},
}
@@ -92,7 +92,7 @@ func TestUpdateServer(t *testing.T) {
"example-source2": &alloydbpg.Source{
Config: alloydbpg.Config{
Name: "example-alloydb-source2",
Kind: "alloydb-postgres",
Type: "alloydb-postgres",
},
},
}

View File

@@ -64,11 +64,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
map[string]prompts.Promptset,
error,
) {
metadataStr := cfg.Version
if len(cfg.UserAgentMetadata) > 0 {
metadataStr += "+" + strings.Join(cfg.UserAgentMetadata, "+")
}
ctx = util.WithUserAgent(ctx, metadataStr)
ctx = util.WithUserAgent(ctx, cfg.Version)
instrumentation, err := util.InstrumentationFromContext(ctx)
if err != nil {
panic(err)
@@ -86,7 +82,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
childCtx, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/source/init",
trace.WithAttributes(attribute.String("source_kind", sc.SourceConfigKind())),
trace.WithAttributes(attribute.String("source_type", sc.SourceConfigType())),
trace.WithAttributes(attribute.String("source_name", name)),
)
defer span.End()
@@ -114,7 +110,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/auth/init",
trace.WithAttributes(attribute.String("auth_kind", sc.AuthServiceConfigKind())),
trace.WithAttributes(attribute.String("auth_type", sc.AuthServiceConfigType())),
trace.WithAttributes(attribute.String("auth_name", name)),
)
defer span.End()
@@ -142,7 +138,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/embeddingmodel/init",
trace.WithAttributes(attribute.String("model_kind", ec.EmbeddingModelConfigKind())),
trace.WithAttributes(attribute.String("model_type", ec.EmbeddingModelConfigType())),
trace.WithAttributes(attribute.String("model_name", name)),
)
defer span.End()
@@ -170,7 +166,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/tool/init",
trace.WithAttributes(attribute.String("tool_kind", tc.ToolConfigKind())),
trace.WithAttributes(attribute.String("tool_type", tc.ToolConfigType())),
trace.WithAttributes(attribute.String("tool_name", name)),
)
defer span.End()
@@ -239,7 +235,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/prompt/init",
trace.WithAttributes(attribute.String("prompt_kind", pc.PromptConfigKind())),
trace.WithAttributes(attribute.String("prompt_type", pc.PromptConfigType())),
trace.WithAttributes(attribute.String("prompt_name", name)),
)
defer span.End()
@@ -308,14 +304,10 @@ func hostCheck(allowedHosts map[string]struct{}) func(http.Handler) http.Handler
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, hasWildcard := allowedHosts["*"]
hostname := r.Host
if host, _, err := net.SplitHostPort(r.Host); err == nil {
hostname = host
}
_, hostIsAllowed := allowedHosts[hostname]
_, hostIsAllowed := allowedHosts[r.Host]
if !hasWildcard && !hostIsAllowed {
// Return 403 Forbidden to block the attack
http.Error(w, "Invalid Host header", http.StatusForbidden)
// Return 400 Bad Request or 403 Forbidden to block the attack
http.Error(w, "Invalid Host header", http.StatusBadRequest)
return
}
next.ServeHTTP(w, r)
@@ -414,11 +406,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
}
allowedHostsMap := make(map[string]struct{}, len(cfg.AllowedHosts))
for _, h := range cfg.AllowedHosts {
hostname := h
if host, _, err := net.SplitHostPort(h); err == nil {
hostname = host
}
allowedHostsMap[hostname] = struct{}{}
allowedHostsMap[h] = struct{}{}
}
r.Use(hostCheck(allowedHostsMap))

View File

@@ -141,7 +141,7 @@ func TestUpdateServer(t *testing.T) {
"example-source": &alloydbpg.Source{
Config: alloydbpg.Config{
Name: "example-alloydb-source",
Kind: "alloydb-postgres",
Type: "alloydb-postgres",
},
},
}

View File

@@ -32,14 +32,14 @@ import (
"google.golang.org/api/option"
)
const SourceKind string = "alloydb-admin"
const SourceType string = "alloydb-admin"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -53,13 +53,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
DefaultProject string `yaml:"defaultProject"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -106,8 +106,8 @@ type Source struct {
Service *alloydbrestapi.Service
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,9 +15,9 @@
package alloydbadmin_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -34,14 +34,14 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-alloydb-admin-instance:
kind: alloydb-admin
kind: sources
name: my-alloydb-admin-instance
type: alloydb-admin
`,
want: map[string]sources.SourceConfig{
"my-alloydb-admin-instance": alloydbadmin.Config{
Name: "my-alloydb-admin-instance",
Kind: alloydbadmin.SourceKind,
Type: alloydbadmin.SourceType,
UseClientOAuth: false,
},
},
@@ -49,15 +49,15 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
{
desc: "use client auth example",
in: `
sources:
my-alloydb-admin-instance:
kind: alloydb-admin
useClientOAuth: true
kind: sources
name: my-alloydb-admin-instance
type: alloydb-admin
useClientOAuth: true
`,
want: map[string]sources.SourceConfig{
"my-alloydb-admin-instance": alloydbadmin.Config{
Name: "my-alloydb-admin-instance",
Kind: alloydbadmin.SourceKind,
Type: alloydbadmin.SourceType,
UseClientOAuth: true,
},
},
@@ -65,16 +65,13 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -89,30 +86,27 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-alloydb-admin-instance:
kind: alloydb-admin
project: test-project
kind: sources
name: my-alloydb-admin-instance
type: alloydb-admin
project: test-project
`,
err: "unable to parse source \"my-alloydb-admin-instance\" as \"alloydb-admin\": [2:1] unknown field \"project\"\n 1 | kind: alloydb-admin\n> 2 | project: test-project\n ^\n",
err: "error unmarshaling sources: unable to parse source \"my-alloydb-admin-instance\" as \"alloydb-admin\": [2:1] unknown field \"project\"\n 1 | name: my-alloydb-admin-instance\n> 2 | project: test-project\n ^\n 3 | type: alloydb-admin",
},
{
desc: "missing required field",
in: `
sources:
my-alloydb-admin-instance:
useClientOAuth: true
kind: sources
name: my-alloydb-admin-instance
useClientOAuth: true
`,
err: "missing 'kind' field for source \"my-alloydb-admin-instance\"",
err: "error unmarshaling sources: missing 'type' field or it is not a string",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "alloydb-postgres"
const SourceType string = "alloydb-postgres"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Cluster string `yaml:"cluster" validate:"required"`
@@ -61,8 +61,8 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -90,8 +90,8 @@ type Source struct {
Pool *pgxpool.Pool
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -183,7 +183,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string
func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, cluster, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
dsn, useIAM, err := getConnectionConfig(ctx, user, pass, dbname)

View File

@@ -15,9 +15,9 @@
package alloydbpg_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -34,21 +34,21 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-pg-instance:
kind: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-pg-instance
type: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
"my-pg-instance": alloydbpg.Config{
Name: "my-pg-instance",
Kind: alloydbpg.SourceKind,
Type: alloydbpg.SourceType,
Project: "my-project",
Region: "my-region",
Cluster: "my-cluster",
@@ -63,22 +63,22 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
{
desc: "public ipType",
in: `
sources:
my-pg-instance:
kind: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
ipType: Public
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-pg-instance
type: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
ipType: Public
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
"my-pg-instance": alloydbpg.Config{
Name: "my-pg-instance",
Kind: alloydbpg.SourceKind,
Type: alloydbpg.SourceType,
Project: "my-project",
Region: "my-region",
Cluster: "my-cluster",
@@ -93,22 +93,22 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
{
desc: "private ipType",
in: `
sources:
my-pg-instance:
kind: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
ipType: private
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-pg-instance
type: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
ipType: private
database: my_db
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
"my-pg-instance": alloydbpg.Config{
Name: "my-pg-instance",
Kind: alloydbpg.SourceKind,
Type: alloydbpg.SourceType,
Project: "my-project",
Region: "my-region",
Cluster: "my-cluster",
@@ -123,16 +123,13 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -147,60 +144,56 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "invalid ipType",
in: `
sources:
my-pg-instance:
kind: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-pg-instance
type: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
},
{
desc: "extra field",
in: `
sources:
my-pg-instance:
kind: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
kind: sources
name: my-pg-instance
type: alloydb-postgres
project: my-project
region: my-region
cluster: my-cluster
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": [3:1] unknown field \"foo\"\n 1 | cluster: my-cluster\n 2 | database: my_db\n> 3 | foo: bar\n ^\n 4 | instance: my-instance\n 5 | kind: alloydb-postgres\n 6 | password: my_pass\n 7 | ",
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": [3:1] unknown field \"foo\"\n 1 | cluster: my-cluster\n 2 | database: my_db\n> 3 | foo: bar\n ^\n 4 | instance: my-instance\n 5 | name: my-pg-instance\n 6 | password: my_pass\n 7 | ",
},
{
desc: "missing required field",
in: `
sources:
my-pg-instance:
kind: alloydb-postgres
region: my-region
cluster: my-cluster
instance: my-instance
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-pg-instance
type: alloydb-postgres
region: my-region
cluster: my-cluster
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"alloydb-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -41,7 +41,7 @@ import (
"google.golang.org/api/option"
)
const SourceKind string = "bigquery"
const SourceType string = "bigquery"
// CloudPlatformScope is a broad scope for Google Cloud Platform services.
const CloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
@@ -65,8 +65,8 @@ type BigQuerySessionProvider func(ctx context.Context) (*Session, error)
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -81,7 +81,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// BigQuery configs
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Project string `yaml:"project" validate:"required"`
Location string `yaml:"location"`
WriteMode string `yaml:"writeMode"`
@@ -119,9 +119,9 @@ func (s *StringOrStringSlice) UnmarshalYAML(unmarshal func(any) error) error {
return fmt.Errorf("cannot unmarshal %T into StringOrStringSlice", v)
}
func (r Config) SourceConfigKind() string {
// Returns BigQuery source kind
return SourceKind
func (r Config) SourceConfigType() string {
// Returns BigQuery source type
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
if r.WriteMode == "" {
@@ -302,9 +302,9 @@ type Session struct {
LastUsed time.Time
}
func (s *Source) SourceKind() string {
// Returns BigQuery Google SQL source kind
return SourceKind
func (s *Source) SourceType() string {
// Returns BigQuery Google SQL source type
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -665,7 +665,7 @@ func initBigQueryConnection(
impersonateServiceAccount string,
scopes []string,
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)
@@ -741,7 +741,7 @@ func initBigQueryConnectionWithOAuthToken(
tokenString string,
wantRestService bool,
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
// Construct token source
token := &oauth2.Token{
@@ -801,7 +801,7 @@ func initDataplexConnection(
var clientCreator DataplexClientCreator
var err error
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -15,18 +15,18 @@
package bigquery_test
import (
"context"
"math/big"
"reflect"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"go.opentelemetry.io/otel/trace/noop"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace/noop"
)
func TestParseFromYamlBigQuery(t *testing.T) {
@@ -38,15 +38,15 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-instance:
kind: bigquery
project: my-project
kind: sources
name: my-instance
type: bigquery
project: my-project
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": bigquery.Config{
Name: "my-instance",
Kind: bigquery.SourceKind,
Type: bigquery.SourceType,
Project: "my-project",
Location: "",
WriteMode: "",
@@ -56,17 +56,17 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "all fields specified",
in: `
sources:
my-instance:
kind: bigquery
project: my-project
location: asia
writeMode: blocked
kind: sources
name: my-instance
type: bigquery
project: my-project
location: asia
writeMode: blocked
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": bigquery.Config{
Name: "my-instance",
Kind: bigquery.SourceKind,
Type: bigquery.SourceType,
Project: "my-project",
Location: "asia",
WriteMode: "blocked",
@@ -77,17 +77,17 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "use client auth example",
in: `
sources:
my-instance:
kind: bigquery
project: my-project
location: us
useClientOAuth: true
kind: sources
name: my-instance
type: bigquery
project: my-project
location: us
useClientOAuth: true
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": bigquery.Config{
Name: "my-instance",
Kind: bigquery.SourceKind,
Type: bigquery.SourceType,
Project: "my-project",
Location: "us",
UseClientOAuth: true,
@@ -97,18 +97,18 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "with allowed datasets example",
in: `
sources:
my-instance:
kind: bigquery
project: my-project
location: us
allowedDatasets:
- my_dataset
kind: sources
name: my-instance
type: bigquery
project: my-project
location: us
allowedDatasets:
- my_dataset
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": bigquery.Config{
Name: "my-instance",
Kind: bigquery.SourceKind,
Type: bigquery.SourceType,
Project: "my-project",
Location: "us",
AllowedDatasets: []string{"my_dataset"},
@@ -118,17 +118,17 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "with service account impersonation example",
in: `
sources:
my-instance:
kind: bigquery
project: my-project
location: us
impersonateServiceAccount: service-account@my-project.iam.gserviceaccount.com
kind: sources
name: my-instance
type: bigquery
project: my-project
location: us
impersonateServiceAccount: service-account@my-project.iam.gserviceaccount.com
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": bigquery.Config{
Name: "my-instance",
Kind: bigquery.SourceKind,
Type: bigquery.SourceType,
Project: "my-project",
Location: "us",
ImpersonateServiceAccount: "service-account@my-project.iam.gserviceaccount.com",
@@ -138,19 +138,19 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "with custom scopes example",
in: `
sources:
my-instance:
kind: bigquery
project: my-project
location: us
scopes:
- https://www.googleapis.com/auth/bigquery
- https://www.googleapis.com/auth/cloud-platform
kind: sources
name: my-instance
type: bigquery
project: my-project
location: us
scopes:
- https://www.googleapis.com/auth/bigquery
- https://www.googleapis.com/auth/cloud-platform
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": bigquery.Config{
Name: "my-instance",
Kind: bigquery.SourceKind,
Type: bigquery.SourceType,
Project: "my-project",
Location: "us",
Scopes: []string{"https://www.googleapis.com/auth/bigquery", "https://www.googleapis.com/auth/cloud-platform"},
@@ -160,17 +160,17 @@ func TestParseFromYamlBigQuery(t *testing.T) {
{
desc: "with max query result rows example",
in: `
sources:
my-instance:
kind: bigquery
project: my-project
location: us
maxQueryResultRows: 10
kind: sources
name: my-instance
type: bigquery
project: my-project
location: us
maxQueryResultRows: 10
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": bigquery.Config{
Name: "my-instance",
Kind: bigquery.SourceKind,
Type: bigquery.SourceType,
Project: "my-project",
Location: "us",
MaxQueryResultRows: 10,
@@ -180,20 +180,15 @@ func TestParseFromYamlBigQuery(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Fatalf("incorrect parse (-want +got):\n%s", diff)
}
})
}
}
func TestFailParseFromYaml(t *testing.T) {
@@ -205,33 +200,29 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-instance:
kind: bigquery
project: my-project
location: us
foo: bar
kind: sources
name: my-instance
type: bigquery
project: my-project
location: us
foo: bar
`,
err: "unable to parse source \"my-instance\" as \"bigquery\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: bigquery\n 3 | location: us\n 4 | project: my-project",
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"bigquery\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | location: us\n 3 | name: my-instance\n 4 | project: my-project\n 5 | ",
},
{
desc: "missing required field",
in: `
sources:
my-instance:
kind: bigquery
location: us
kind: sources
name: my-instance
type: bigquery
location: us
`,
err: "unable to parse source \"my-instance\" as \"bigquery\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"bigquery\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}
@@ -260,7 +251,7 @@ func TestInitialize_MaxQueryResultRows(t *testing.T) {
desc: "default value",
cfg: bigquery.Config{
Name: "test-default",
Kind: bigquery.SourceKind,
Type: bigquery.SourceType,
Project: "test-project",
UseClientOAuth: true,
},
@@ -270,7 +261,7 @@ func TestInitialize_MaxQueryResultRows(t *testing.T) {
desc: "configured value",
cfg: bigquery.Config{
Name: "test-configured",
Kind: bigquery.SourceKind,
Type: bigquery.SourceType,
Project: "test-project",
UseClientOAuth: true,
MaxQueryResultRows: 100,

View File

@@ -27,14 +27,14 @@ import (
"google.golang.org/api/option"
)
const SourceKind string = "bigtable"
const SourceType string = "bigtable"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -48,13 +48,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Project string `yaml:"project" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -77,8 +77,8 @@ type Source struct {
Client *bigtable.Client
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -179,7 +179,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, configParam param
func initBigtableClient(ctx context.Context, tracer trace.Tracer, name, project, instance string) (*bigtable.Client, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
// Set up Bigtable data operations client.

View File

@@ -15,9 +15,9 @@
package bigtable_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -34,16 +34,16 @@ func TestParseFromYamlBigtableDb(t *testing.T) {
{
desc: "can configure with a bigtable table",
in: `
sources:
my-bigtable-instance:
kind: bigtable
project: my-project
instance: my-instance
kind: sources
name: my-bigtable-instance
type: bigtable
project: my-project
instance: my-instance
`,
want: map[string]sources.SourceConfig{
"my-bigtable-instance": bigtable.Config{
Name: "my-bigtable-instance",
Kind: bigtable.SourceKind,
Type: bigtable.SourceType,
Project: "my-project",
Instance: "my-instance",
},
@@ -52,16 +52,12 @@ func TestParseFromYamlBigtableDb(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -77,33 +73,29 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-bigtable-instance:
kind: bigtable
project: my-project
instance: my-instance
foo: bar
kind: sources
name: my-bigtable-instance
type: bigtable
project: my-project
instance: my-instance
foo: bar
`,
err: "unable to parse source \"my-bigtable-instance\" as \"bigtable\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | instance: my-instance\n 3 | kind: bigtable\n 4 | project: my-project",
err: "error unmarshaling sources: unable to parse source \"my-bigtable-instance\" as \"bigtable\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | instance: my-instance\n 3 | name: my-bigtable-instance\n 4 | project: my-project\n 5 | ",
},
{
desc: "missing required field",
in: `
sources:
my-bigtable-instance:
kind: bigtable
project: my-project
kind: sources
name: my-bigtable-instance
type: bigtable
project: my-project
`,
err: "unable to parse source \"my-bigtable-instance\" as \"bigtable\": Key: 'Config.Instance' Error:Field validation for 'Instance' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-bigtable-instance\" as \"bigtable\": Key: 'Config.Instance' Error:Field validation for 'Instance' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -25,11 +25,11 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "cassandra"
const SourceType string = "cassandra"
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -43,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Hosts []string `yaml:"hosts" validate:"required"`
Keyspace string `yaml:"keyspace"`
ProtoVersion int `yaml:"protoVersion"`
@@ -68,9 +68,9 @@ func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return s, nil
}
// SourceConfigKind implements sources.SourceConfig.
func (c Config) SourceConfigKind() string {
return SourceKind
// SourceConfigType implements sources.SourceConfig.
func (c Config) SourceConfigType() string {
return SourceType
}
var _ sources.SourceConfig = Config{}
@@ -89,9 +89,9 @@ func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
// SourceKind implements sources.Source.
func (s *Source) SourceKind() string {
return SourceKind
// SourceType implements sources.Source.
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) RunSQL(ctx context.Context, statement string, params parameters.ParamValues) (any, error) {
@@ -120,7 +120,7 @@ var _ sources.Source = &Source{}
func initCassandraSession(ctx context.Context, tracer trace.Tracer, c Config) (*gocql.Session, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, c.Name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, c.Name)
defer span.End()
// Validate authentication configuration

View File

@@ -15,11 +15,12 @@
package cassandra_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,17 +34,17 @@ func TestParseFromYamlCassandra(t *testing.T) {
{
desc: "basic example (without optional fields)",
in: `
sources:
my-cassandra-instance:
kind: cassandra
hosts:
- "my-host1"
- "my-host2"
kind: sources
name: my-cassandra-instance
type: cassandra
hosts:
- "my-host1"
- "my-host2"
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-cassandra-instance": cassandra.Config{
Name: "my-cassandra-instance",
Kind: cassandra.SourceKind,
Type: cassandra.SourceType,
Hosts: []string{"my-host1", "my-host2"},
Username: "",
Password: "",
@@ -59,25 +60,25 @@ func TestParseFromYamlCassandra(t *testing.T) {
{
desc: "with optional fields",
in: `
sources:
my-cassandra-instance:
kind: cassandra
hosts:
- "my-host1"
- "my-host2"
username: "user"
password: "pass"
keyspace: "example_keyspace"
protoVersion: 4
caPath: "path/to/ca.crt"
certPath: "path/to/cert"
keyPath: "path/to/key"
enableHostVerification: true
kind: sources
name: my-cassandra-instance
type: cassandra
hosts:
- "my-host1"
- "my-host2"
username: "user"
password: "pass"
keyspace: "example_keyspace"
protoVersion: 4
caPath: "path/to/ca.crt"
certPath: "path/to/cert"
keyPath: "path/to/key"
enableHostVerification: true
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-cassandra-instance": cassandra.Config{
Name: "my-cassandra-instance",
Kind: cassandra.SourceKind,
Type: cassandra.SourceType,
Hosts: []string{"my-host1", "my-host2"},
Username: "user",
Password: "pass",
@@ -93,16 +94,12 @@ func TestParseFromYamlCassandra(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -118,33 +115,29 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-cassandra-instance:
kind: cassandra
hosts:
- "my-host"
foo: bar
kind: sources
name: my-cassandra-instance
type: cassandra
hosts:
- "my-host"
foo: bar
`,
err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | hosts:\n 3 | - my-host\n 4 | kind: cassandra",
err: "error unmarshaling sources: unable to parse source \"my-cassandra-instance\" as \"cassandra\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | hosts:\n 3 | - my-host\n 4 | name: my-cassandra-instance\n 5 | ",
},
{
desc: "missing required field",
in: `
sources:
my-cassandra-instance:
kind: cassandra
kind: sources
name: my-cassandra-instance
type: cassandra
`,
err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": Key: 'Config.Hosts' Error:Field validation for 'Hosts' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-cassandra-instance\" as \"cassandra\": Key: 'Config.Hosts' Error:Field validation for 'Hosts' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -28,14 +28,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "clickhouse"
const SourceType string = "clickhouse"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
Database string `yaml:"database" validate:"required"`
@@ -59,8 +59,8 @@ type Config struct {
Secure bool `yaml:"secure"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -88,8 +88,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -174,7 +174,7 @@ func validateConfig(protocol string) error {
func initClickHouseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, protocol string, secure bool) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
if protocol == "" {

View File

@@ -21,137 +21,113 @@ import (
"github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/testutils"
"go.opentelemetry.io/otel"
)
func TestConfigSourceConfigKind(t *testing.T) {
config := Config{}
if config.SourceConfigKind() != SourceKind {
t.Errorf("Expected %s, got %s", SourceKind, config.SourceConfigKind())
func TestParseFromYamlClickhouse(t *testing.T) {
tcs := []struct {
desc string
in string
want server.SourceConfigs
}{
{
desc: "all fields specified",
in: `
kind: sources
name: test-clickhouse
type: clickhouse
host: localhost
port: "8443"
user: default
password: "mypass"
database: mydb
protocol: https
secure: true
`,
want: map[string]sources.SourceConfig{
"test-clickhouse": Config{
Name: "test-clickhouse",
Type: "clickhouse",
Host: "localhost",
Port: "8443",
User: "default",
Password: "mypass",
Database: "mydb",
Protocol: "https",
Secure: true,
},
},
},
{
desc: "minimal configuration with defaults",
in: `
kind: sources
name: minimal-clickhouse
type: clickhouse
host: 127.0.0.1
port: "8123"
user: testuser
database: testdb
`,
want: map[string]sources.SourceConfig{
"minimal-clickhouse": Config{
Name: "minimal-clickhouse",
Type: "clickhouse",
Host: "127.0.0.1",
Port: "8123",
User: "testuser",
Password: "",
Database: "testdb",
Protocol: "",
Secure: false,
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
}
func TestNewConfig(t *testing.T) {
tests := []struct {
name string
yaml string
expected Config
func TestFailParseFromYaml(t *testing.T) {
tcs := []struct {
desc string
in string
err string
}{
{
name: "all fields specified",
yaml: `
name: test-clickhouse
kind: clickhouse
host: localhost
port: "8443"
user: default
password: "mypass"
database: mydb
protocol: https
secure: true
desc: "extra field",
in: `
kind: sources
name: test-clickhouse
type: clickhouse
host: localhost
foo: bar
`,
expected: Config{
Name: "test-clickhouse",
Kind: "clickhouse",
Host: "localhost",
Port: "8443",
User: "default",
Password: "mypass",
Database: "mydb",
Protocol: "https",
Secure: true,
},
},
{
name: "minimal configuration with defaults",
yaml: `
name: minimal-clickhouse
kind: clickhouse
host: 127.0.0.1
port: "8123"
user: testuser
database: testdb
`,
expected: Config{
Name: "minimal-clickhouse",
Kind: "clickhouse",
Host: "127.0.0.1",
Port: "8123",
User: "testuser",
Password: "",
Database: "testdb",
Protocol: "",
Secure: false,
},
},
{
name: "http protocol",
yaml: `
name: http-clickhouse
kind: clickhouse
host: clickhouse.example.com
port: "8123"
user: analytics
password: "securepass"
database: analytics_db
protocol: http
secure: false
`,
expected: Config{
Name: "http-clickhouse",
Kind: "clickhouse",
Host: "clickhouse.example.com",
Port: "8123",
User: "analytics",
Password: "securepass",
Database: "analytics_db",
Protocol: "http",
Secure: false,
},
},
{
name: "https with secure connection",
yaml: `
name: secure-clickhouse
kind: clickhouse
host: secure.clickhouse.io
port: "8443"
user: secureuser
password: "verysecure"
database: production
protocol: https
secure: true
`,
expected: Config{
Name: "secure-clickhouse",
Kind: "clickhouse",
Host: "secure.clickhouse.io",
Port: "8443",
User: "secureuser",
Password: "verysecure",
Database: "production",
Protocol: "https",
Secure: true,
},
err: "error unmarshaling sources: unable to parse source \"test-clickhouse\" as \"clickhouse\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | host: localhost\n 3 | name: test-clickhouse\n 4 | type: clickhouse",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
decoder := yaml.NewDecoder(strings.NewReader(string(testutils.FormatYaml(tt.yaml))))
config, err := newConfig(context.Background(), tt.expected.Name, decoder)
if err != nil {
t.Fatalf("Failed to create config: %v", err)
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}
clickhouseConfig, ok := config.(Config)
if !ok {
t.Fatalf("Expected Config type, got %T", config)
}
if diff := cmp.Diff(tt.expected, clickhouseConfig); diff != "" {
t.Errorf("Config mismatch (-want +got):\n%s", diff)
errStr := err.Error()
if errStr != tc.err {
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
}
})
}
@@ -167,19 +143,11 @@ func TestNewConfigInvalidYAML(t *testing.T) {
name: "invalid yaml syntax",
yaml: `
name: test-clickhouse
kind: clickhouse
type: clickhouse
host: [invalid
`,
expectError: true,
},
{
name: "missing required fields",
yaml: `
name: test-clickhouse
kind: clickhouse
`,
expectError: false,
},
}
for _, tt := range tests {
@@ -196,10 +164,10 @@ func TestNewConfigInvalidYAML(t *testing.T) {
}
}
func TestSource_SourceKind(t *testing.T) {
func TestSource_SourceType(t *testing.T) {
source := &Source{}
if source.SourceKind() != SourceKind {
t.Errorf("Expected %s, got %s", SourceKind, source.SourceKind())
if source.SourceType() != SourceType {
t.Errorf("Expected %s, got %s", SourceType, source.SourceType())
}
}

View File

@@ -29,15 +29,15 @@ import (
"golang.org/x/oauth2/google"
)
const SourceKind string = "cloud-gemini-data-analytics"
const SourceType string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -51,13 +51,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
ProjectID string `yaml:"projectId" validate:"required"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
// Initialize initializes a Gemini Data Analytics Source instance.
@@ -102,8 +102,8 @@ type Source struct {
userAgent string
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -20,7 +20,6 @@ import (
"path/filepath"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -39,15 +38,15 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-gda-instance:
kind: cloud-gemini-data-analytics
projectId: test-project-id
`,
kind: sources
name: my-gda-instance
type: cloud-gemini-data-analytics
projectId: test-project-id
`,
want: map[string]sources.SourceConfig{
"my-gda-instance": cloudgda.Config{
Name: "my-gda-instance",
Kind: cloudgda.SourceKind,
Type: cloudgda.SourceType,
ProjectID: "test-project-id",
UseClientOAuth: false,
},
@@ -56,16 +55,16 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
{
desc: "use client auth example",
in: `
sources:
my-gda-instance:
kind: cloud-gemini-data-analytics
projectId: another-project
useClientOAuth: true
kind: sources
name: my-gda-instance
type: cloud-gemini-data-analytics
projectId: another-project
useClientOAuth: true
`,
want: map[string]sources.SourceConfig{
"my-gda-instance": cloudgda.Config{
Name: "my-gda-instance",
Kind: cloudgda.SourceKind,
Type: cloudgda.SourceType,
ProjectID: "another-project",
UseClientOAuth: true,
},
@@ -76,16 +75,12 @@ func TestParseFromYamlCloudGDA(t *testing.T) {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -101,22 +96,18 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "missing projectId",
in: `
sources:
my-gda-instance:
kind: cloud-gemini-data-analytics
kind: sources
name: my-gda-instance
type: cloud-gemini-data-analytics
`,
err: "unable to parse source \"my-gda-instance\" as \"cloud-gemini-data-analytics\": Key: 'Config.ProjectID' Error:Field validation for 'ProjectID' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-gda-instance\" as \"cloud-gemini-data-analytics\": Key: 'Config.ProjectID' Error:Field validation for 'ProjectID' failed on the 'required' tag",
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}
@@ -153,12 +144,12 @@ func TestInitialize(t *testing.T) {
}{
{
desc: "initialize with ADC",
cfg: cloudgda.Config{Name: "test-gda", Kind: cloudgda.SourceKind, ProjectID: "test-proj"},
cfg: cloudgda.Config{Name: "test-gda", Type: cloudgda.SourceType, ProjectID: "test-proj"},
wantClientOAuth: false,
},
{
desc: "initialize with client OAuth",
cfg: cloudgda.Config{Name: "test-gda-oauth", Kind: cloudgda.SourceKind, ProjectID: "test-proj", UseClientOAuth: true},
cfg: cloudgda.Config{Name: "test-gda-oauth", Type: cloudgda.SourceType, ProjectID: "test-proj", UseClientOAuth: true},
wantClientOAuth: true,
},
}

View File

@@ -34,7 +34,7 @@ import (
"google.golang.org/api/option"
)
const SourceKind string = "cloud-healthcare"
const SourceType string = "cloud-healthcare"
// validate interface
var _ sources.SourceConfig = Config{}
@@ -42,8 +42,8 @@ var _ sources.SourceConfig = Config{}
type HealthcareServiceCreator func(tokenString string) (*healthcare.Service, error)
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -58,7 +58,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Healthcare configs
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Dataset string `yaml:"dataset" validate:"required"`
@@ -67,8 +67,8 @@ type Config struct {
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (c Config) SourceConfigKind() string {
return SourceKind
func (c Config) SourceConfigType() string {
return SourceType
}
func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -144,7 +144,7 @@ func newHealthcareServiceCreator(ctx context.Context, tracer trace.Tracer, name
}
func initHealthcareConnectionWithOAuthToken(ctx context.Context, tracer trace.Tracer, name string, userAgent string, tokenString string) (*healthcare.Service, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
// Construct token source
token := &oauth2.Token{
@@ -162,7 +162,7 @@ func initHealthcareConnectionWithOAuthToken(ctx context.Context, tracer trace.Tr
}
func initHealthcareConnection(ctx context.Context, tracer trace.Tracer, name string) (*healthcare.Service, oauth2.TokenSource, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
cred, err := google.FindDefaultCredentials(ctx, healthcare.CloudHealthcareScope)
@@ -194,8 +194,8 @@ type Source struct {
allowedDICOMStores map[string]struct{}
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -517,14 +517,14 @@ func (s *Source) RetrieveRenderedDICOMInstance(storeID, study, series, sop strin
return base64String, nil
}
func (s *Source) SearchDICOM(toolKind, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
func (s *Source) SearchDICOM(toolType, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
var resp *http.Response
switch toolKind {
switch toolType {
case "cloud-healthcare-search-dicom-instances":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
case "cloud-healthcare-search-dicom-series":
@@ -532,7 +532,7 @@ func (s *Source) SearchDICOM(toolKind, storeID, dicomWebPath, tokenStr string, o
case "cloud-healthcare-search-dicom-studies":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, dicomWebPath).Do(opts...)
default:
return nil, fmt.Errorf("incompatible tool kind: %s", toolKind)
return nil, fmt.Errorf("incompatible tool type: %s", toolType)
}
if err != nil {
return nil, fmt.Errorf("failed to search dicom series: %w", err)

View File

@@ -15,11 +15,12 @@
package cloudhealthcare_test
import (
"context"
"testing"
"github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,17 +34,17 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-instance:
kind: cloud-healthcare
project: my-project
region: us-central1
dataset: my-dataset
kind: sources
name: my-instance
type: cloud-healthcare
project: my-project
region: us-central1
dataset: my-dataset
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": cloudhealthcare.Config{
Name: "my-instance",
Kind: cloudhealthcare.SourceKind,
Type: cloudhealthcare.SourceType,
Project: "my-project",
Region: "us-central1",
Dataset: "my-dataset",
@@ -54,18 +55,18 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
{
desc: "use client auth example",
in: `
sources:
my-instance:
kind: cloud-healthcare
project: my-project
region: us
dataset: my-dataset
useClientOAuth: true
kind: sources
name: my-instance
type: cloud-healthcare
project: my-project
region: us
dataset: my-dataset
useClientOAuth: true
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": cloudhealthcare.Config{
Name: "my-instance",
Kind: cloudhealthcare.SourceKind,
Type: cloudhealthcare.SourceType,
Project: "my-project",
Region: "us",
Dataset: "my-dataset",
@@ -76,22 +77,22 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
{
desc: "with allowed stores example",
in: `
sources:
my-instance:
kind: cloud-healthcare
project: my-project
region: us
dataset: my-dataset
allowedFhirStores:
- my-fhir-store
allowedDicomStores:
- my-dicom-store1
- my-dicom-store2
kind: sources
name: my-instance
type: cloud-healthcare
project: my-project
region: us
dataset: my-dataset
allowedFhirStores:
- my-fhir-store
allowedDicomStores:
- my-dicom-store1
- my-dicom-store2
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": cloudhealthcare.Config{
Name: "my-instance",
Kind: cloudhealthcare.SourceKind,
Type: cloudhealthcare.SourceType,
Project: "my-project",
Region: "us",
Dataset: "my-dataset",
@@ -103,16 +104,12 @@ func TestParseFromYamlCloudHealthcare(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -127,35 +124,31 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-instance:
kind: cloud-healthcare
project: my-project
region: us-central1
dataset: my-dataset
foo: bar
kind: sources
name: my-instance
type: cloud-healthcare
project: my-project
region: us-central1
dataset: my-dataset
foo: bar
`,
err: "unable to parse source \"my-instance\" as \"cloud-healthcare\": [2:1] unknown field \"foo\"\n 1 | dataset: my-dataset\n> 2 | foo: bar\n ^\n 3 | kind: cloud-healthcare\n 4 | project: my-project\n 5 | region: us-central1",
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-healthcare\": [2:1] unknown field \"foo\"\n 1 | dataset: my-dataset\n> 2 | foo: bar\n ^\n 3 | name: my-instance\n 4 | project: my-project\n 5 | region: us-central1\n 6 | ",
},
{
desc: "missing required field",
in: `
sources:
my-instance:
kind: cloud-healthcare
project: my-project
region: us-central1
kind: sources
name: my-instance
type: cloud-healthcare
project: my-project
region: us-central1
`,
err: `unable to parse source "my-instance" as "cloud-healthcare": Key: 'Config.Dataset' Error:Field validation for 'Dataset' failed on the 'required' tag`,
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-healthcare\": Key: 'Config.Dataset' Error:Field validation for 'Dataset' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
monitoring "google.golang.org/api/monitoring/v3"
)
const SourceKind string = "cloud-monitoring"
const SourceType string = "cloud-monitoring"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -50,12 +50,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
// Initialize initializes a Cloud Monitoring Source instance.
@@ -99,8 +99,8 @@ type Source struct {
userAgent string
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,9 +15,9 @@
package cloudmonitoring_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -35,14 +35,14 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-cloud-monitoring-instance:
kind: cloud-monitoring
kind: sources
name: my-cloud-monitoring-instance
type: cloud-monitoring
`,
want: map[string]sources.SourceConfig{
"my-cloud-monitoring-instance": cloudmonitoring.Config{
Name: "my-cloud-monitoring-instance",
Kind: cloudmonitoring.SourceKind,
Type: cloudmonitoring.SourceType,
UseClientOAuth: false,
},
},
@@ -50,15 +50,15 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
{
desc: "use client auth example",
in: `
sources:
my-cloud-monitoring-instance:
kind: cloud-monitoring
useClientOAuth: true
kind: sources
name: my-cloud-monitoring-instance
type: cloud-monitoring
useClientOAuth: true
`,
want: map[string]sources.SourceConfig{
"my-cloud-monitoring-instance": cloudmonitoring.Config{
Name: "my-cloud-monitoring-instance",
Kind: cloudmonitoring.SourceKind,
Type: cloudmonitoring.SourceType,
UseClientOAuth: true,
},
},
@@ -68,16 +68,12 @@ func TestParseFromYamlCloudMonitoring(t *testing.T) {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -93,36 +89,28 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-cloud-monitoring-instance:
kind: cloud-monitoring
project: test-project
kind: sources
name: my-cloud-monitoring-instance
type: cloud-monitoring
project: test-project
`,
err: `unable to parse source "my-cloud-monitoring-instance" as "cloud-monitoring": [2:1] unknown field "project"
1 | kind: cloud-monitoring
> 2 | project: test-project
^
`,
err: "error unmarshaling sources: unable to parse source \"my-cloud-monitoring-instance\" as \"cloud-monitoring\": [2:1] unknown field \"project\"\n 1 | name: my-cloud-monitoring-instance\n> 2 | project: test-project\n ^\n 3 | type: cloud-monitoring",
},
{
desc: "missing required field",
in: `
sources:
my-cloud-monitoring-instance:
useClientOAuth: true
kind: sources
name: my-cloud-monitoring-instance
useClientOAuth: true
`,
err: "missing 'kind' field for source \"my-cloud-monitoring-instance\"",
err: "error unmarshaling sources: missing 'type' field or it is not a string",
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -19,7 +19,6 @@ import (
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
"text/template"
"time"
@@ -35,19 +34,16 @@ import (
sqladmin "google.golang.org/api/sqladmin/v1"
)
const SourceKind string = "cloud-sql-admin"
const SourceType string = "cloud-sql-admin"
var (
targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
backupDRRegex = regexp.MustCompile(`^projects/([^/]+)/locations/([^/]+)/backupVaults/([^/]+)/dataSources/([^/]+)/backups/([^/]+)$`)
)
var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -61,13 +57,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
DefaultProject string `yaml:"defaultProject"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
// Initialize initializes a CloudSQL Admin Source instance.
@@ -114,8 +110,8 @@ type Source struct {
Service *sqladmin.Service
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -378,48 +374,6 @@ func (s *Source) InsertBackupRun(ctx context.Context, project, instance, locatio
return resp, nil
}
func (s *Source) RestoreBackup(ctx context.Context, targetProject, targetInstance, sourceProject, sourceInstance, backupID, accessToken string) (any, error) {
request := &sqladmin.InstancesRestoreBackupRequest{}
// There are 3 scenarios for the backup identifier:
// 1. The identifier is an int64 containing the timestamp of the BackupRun.
// This is used to restore standard backups, and the RestoreBackupContext
// field should be populated with the backup ID and source instance info.
// 2. The identifier is a string of the format
// 'projects/{project-id}/locations/{location}/backupVaults/{backupvault}/dataSources/{datasource}/backups/{backup-uid}'.
// This is used to restore BackupDR backups, and the BackupdrBackup field
// should be populated.
// 3. The identifer is a string of the format
// 'projects/{project-id}/backups/{backup-uid}'. In this case, the Backup
// field should be populated.
if backupRunID, err := strconv.ParseInt(backupID, 10, 64); err == nil {
if sourceProject == "" || targetInstance == "" {
return nil, fmt.Errorf("source project and instance are required when restoring via backup ID")
}
request.RestoreBackupContext = &sqladmin.RestoreBackupContext{
Project: sourceProject,
InstanceId: sourceInstance,
BackupRunId: backupRunID,
}
} else if backupDRRegex.MatchString(backupID) {
request.BackupdrBackup = backupID
} else {
request.Backup = backupID
}
service, err := s.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
resp, err := service.Instances.RestoreBackup(targetProject, targetInstance, request).Do()
if err != nil {
return nil, fmt.Errorf("error restoring backup: %w", err)
}
return resp, nil
}
func generateCloudSQLConnectionMessage(ctx context.Context, source *Source, logger log.Logger, opResponse map[string]any, connectionMessageTemplate string) (string, bool) {
operationType, ok := opResponse["operationType"].(string)
if !ok || operationType != "CREATE_DATABASE" {

View File

@@ -15,9 +15,9 @@
package cloudsqladmin_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -35,14 +35,14 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-cloud-sql-admin-instance:
kind: cloud-sql-admin
kind: sources
name: my-cloud-sql-admin-instance
type: cloud-sql-admin
`,
want: map[string]sources.SourceConfig{
"my-cloud-sql-admin-instance": cloudsqladmin.Config{
Name: "my-cloud-sql-admin-instance",
Kind: cloudsqladmin.SourceKind,
Type: cloudsqladmin.SourceType,
UseClientOAuth: false,
},
},
@@ -50,15 +50,15 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
{
desc: "use client auth example",
in: `
sources:
my-cloud-sql-admin-instance:
kind: cloud-sql-admin
useClientOAuth: true
kind: sources
name: my-cloud-sql-admin-instance
type: cloud-sql-admin
useClientOAuth: true
`,
want: map[string]sources.SourceConfig{
"my-cloud-sql-admin-instance": cloudsqladmin.Config{
Name: "my-cloud-sql-admin-instance",
Kind: cloudsqladmin.SourceKind,
Type: cloudsqladmin.SourceType,
UseClientOAuth: true,
},
},
@@ -68,16 +68,12 @@ func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -93,36 +89,28 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-cloud-sql-admin-instance:
kind: cloud-sql-admin
project: test-project
kind: sources
name: my-cloud-sql-admin-instance
type: cloud-sql-admin
project: test-project
`,
err: `unable to parse source "my-cloud-sql-admin-instance" as "cloud-sql-admin": [2:1] unknown field "project"
1 | kind: cloud-sql-admin
> 2 | project: test-project
^
`,
err: "error unmarshaling sources: unable to parse source \"my-cloud-sql-admin-instance\" as \"cloud-sql-admin\": [2:1] unknown field \"project\"\n 1 | name: my-cloud-sql-admin-instance\n> 2 | project: test-project\n ^\n 3 | type: cloud-sql-admin",
},
{
desc: "missing required field",
in: `
sources:
my-cloud-sql-admin-instance:
useClientOAuth: true
kind: sources
name: my-cloud-sql-admin-instance
useClientOAuth: true
`,
err: "missing 'kind' field for source \"my-cloud-sql-admin-instance\"",
err: "error unmarshaling sources: missing 'type' field or it is not a string",
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "cloud-sql-mssql"
const SourceType string = "cloud-sql-mssql"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Cloud SQL MSSQL configs
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
@@ -62,9 +62,9 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceKind
func (r Config) SourceConfigType() string {
// Returns Cloud SQL MSSQL source type
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -94,9 +94,9 @@ type Source struct {
Db *sql.DB
}
func (s *Source) SourceKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceKind
func (s *Source) SourceType() string {
// Returns Cloud SQL MSSQL source type
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -152,7 +152,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -15,11 +15,12 @@
package cloudsqlmssql_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,20 +34,20 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-instance:
kind: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-instance
type: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": cloudsqlmssql.Config{
Name: "my-instance",
Kind: cloudsqlmssql.SourceKind,
Type: cloudsqlmssql.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -60,21 +61,21 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
{
desc: "psc ipType",
in: `
sources:
my-instance:
kind: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
ipType: psc
kind: sources
name: my-instance
type: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
ipType: psc
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": cloudsqlmssql.Config{
Name: "my-instance",
Kind: cloudsqlmssql.SourceKind,
Type: cloudsqlmssql.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -88,21 +89,21 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
{
desc: "with deprecated ipAddress",
in: `
sources:
my-instance:
kind: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
ipAddress: random
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-instance
type: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
ipAddress: random
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": cloudsqlmssql.Config{
Name: "my-instance",
Kind: cloudsqlmssql.SourceKind,
Type: cloudsqlmssql.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -117,16 +118,12 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got)
}
})
}
@@ -142,57 +139,53 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "invalid ipType",
in: `
sources:
my-instance:
kind: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-instance
type: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-instance\" as \"cloud-sql-mssql\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-sql-mssql\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
},
{
desc: "extra field",
in: `
sources:
my-instance:
kind: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
kind: sources
name: my-instance
type: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "unable to parse source \"my-instance\" as \"cloud-sql-mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | kind: cloud-sql-mssql\n 5 | password: my_pass\n 6 | ",
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-sql-mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | name: my-instance\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
sources:
my-instance:
kind: cloud-sql-mssql
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-instance
type: cloud-sql-mssql
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-instance\" as \"cloud-sql-mssql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"cloud-sql-mssql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -30,14 +30,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "cloud-sql-mysql"
const SourceType string = "cloud-sql-mysql"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
@@ -61,8 +61,8 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -90,8 +90,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -184,7 +184,7 @@ func getConnectionConfig(ctx context.Context, user, pass string) (string, string
func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
// Configure the driver to connect to the database

View File

@@ -15,11 +15,12 @@
package cloudsqlmysql_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,20 +34,20 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-mysql-instance:
kind: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-mysql-instance": cloudsqlmysql.Config{
Name: "my-mysql-instance",
Kind: cloudsqlmysql.SourceKind,
Type: cloudsqlmysql.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -60,21 +61,21 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "public ipType",
in: `
sources:
my-mysql-instance:
kind: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: Public
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: Public
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-mysql-instance": cloudsqlmysql.Config{
Name: "my-mysql-instance",
Kind: cloudsqlmysql.SourceKind,
Type: cloudsqlmysql.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -88,21 +89,21 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "private ipType",
in: `
sources:
my-mysql-instance:
kind: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: private
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: private
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-mysql-instance": cloudsqlmysql.Config{
Name: "my-mysql-instance",
Kind: cloudsqlmysql.SourceKind,
Type: cloudsqlmysql.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -116,21 +117,21 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "psc ipType",
in: `
sources:
my-mysql-instance:
kind: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: psc
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: psc
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-mysql-instance": cloudsqlmysql.Config{
Name: "my-mysql-instance",
Kind: cloudsqlmysql.SourceKind,
Type: cloudsqlmysql.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -144,16 +145,12 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -169,57 +166,53 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "invalid ipType",
in: `
sources:
my-mysql-instance:
kind: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
},
{
desc: "extra field",
in: `
sources:
my-mysql-instance:
kind: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | kind: cloud-sql-mysql\n 5 | password: my_pass\n 6 | ",
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | name: my-mysql-instance\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
sources:
my-mysql-instance:
kind: cloud-sql-mysql
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-mysql-instance
type: cloud-sql-mysql
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"cloud-sql-mysql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -28,14 +28,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "cloud-sql-postgres"
const SourceType string = "cloud-sql-postgres"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -49,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Project string `yaml:"project" validate:"required"`
Region string `yaml:"region" validate:"required"`
Instance string `yaml:"instance" validate:"required"`
@@ -59,8 +59,8 @@ type Config struct {
Password string `yaml:"password"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -88,8 +88,8 @@ type Source struct {
Pool *pgxpool.Pool
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -162,7 +162,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string
func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
// Configure the driver to connect to the database

View File

@@ -15,11 +15,12 @@
package cloudsqlpg_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,20 +34,20 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Kind: cloudsqlpg.SourceKind,
Type: cloudsqlpg.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -60,21 +61,21 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
{
desc: "public ipType",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: Public
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: Public
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Kind: cloudsqlpg.SourceKind,
Type: cloudsqlpg.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -88,21 +89,21 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
{
desc: "private ipType",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: private
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: private
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Kind: cloudsqlpg.SourceKind,
Type: cloudsqlpg.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -116,21 +117,21 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
{
desc: "psc ipType",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: psc
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: psc
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Kind: cloudsqlpg.SourceKind,
Type: cloudsqlpg.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
@@ -144,16 +145,12 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -169,57 +166,53 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "invalid ipType",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ipType: fail
database: my_db
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": ipType invalid: must be one of \"public\", \"private\", or \"psc\"",
},
{
desc: "extra field",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | kind: cloud-sql-postgres\n 5 | password: my_pass\n 6 | ",
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | name: my-pg-instance\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"cloud-sql-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "couchbase"
const SourceType string = "couchbase"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
ConnectionString string `yaml:"connectionString" validate:"required"`
Bucket string `yaml:"bucket" validate:"required"`
Scope string `yaml:"scope" validate:"required"`
@@ -66,8 +66,8 @@ type Config struct {
QueryScanConsistency uint `yaml:"queryScanConsistency"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -96,8 +96,8 @@ type Source struct {
Scope *gocb.Scope
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,11 +15,12 @@
package couchbase_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/couchbase"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,19 +34,19 @@ func TestParseFromYamlCouchbase(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-couchbase-instance:
kind: couchbase
connectionString: localhost
username: Administrator
password: password
bucket: travel-sample
scope: inventory
kind: sources
name: my-couchbase-instance
type: couchbase
connectionString: localhost
username: Administrator
password: password
bucket: travel-sample
scope: inventory
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-couchbase-instance": couchbase.Config{
Name: "my-couchbase-instance",
Kind: couchbase.SourceKind,
Type: couchbase.SourceType,
ConnectionString: "localhost",
Username: "Administrator",
Password: "password",
@@ -57,24 +58,24 @@ func TestParseFromYamlCouchbase(t *testing.T) {
{
desc: "with TLS configuration",
in: `
sources:
my-couchbase-instance:
kind: couchbase
connectionString: couchbases://localhost
bucket: travel-sample
scope: inventory
clientCert: /path/to/cert.pem
clientKey: /path/to/key.pem
clientCertPassword: password
clientKeyPassword: password
caCert: /path/to/ca.pem
noSslVerify: false
queryScanConsistency: 2
kind: sources
name: my-couchbase-instance
type: couchbase
connectionString: couchbases://localhost
bucket: travel-sample
scope: inventory
clientCert: /path/to/cert.pem
clientKey: /path/to/key.pem
clientCertPassword: password
clientKeyPassword: password
caCert: /path/to/ca.pem
noSslVerify: false
queryScanConsistency: 2
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-couchbase-instance": couchbase.Config{
Name: "my-couchbase-instance",
Kind: couchbase.SourceKind,
Type: couchbase.SourceType,
ConnectionString: "couchbases://localhost",
Bucket: "travel-sample",
Scope: "inventory",
@@ -91,16 +92,12 @@ func TestParseFromYamlCouchbase(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -115,39 +112,35 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-couchbase-instance:
kind: couchbase
connectionString: localhost
username: Administrator
password: password
bucket: travel-sample
scope: inventory
foo: bar
kind: sources
name: my-couchbase-instance
type: couchbase
connectionString: localhost
username: Administrator
password: password
bucket: travel-sample
scope: inventory
foo: bar
`,
err: "unable to parse source \"my-couchbase-instance\" as \"couchbase\": [3:1] unknown field \"foo\"\n 1 | bucket: travel-sample\n 2 | connectionString: localhost\n> 3 | foo: bar\n ^\n 4 | kind: couchbase\n 5 | password: password\n 6 | scope: inventory\n 7 | ",
err: "error unmarshaling sources: unable to parse source \"my-couchbase-instance\" as \"couchbase\": [3:1] unknown field \"foo\"\n 1 | bucket: travel-sample\n 2 | connectionString: localhost\n> 3 | foo: bar\n ^\n 4 | name: my-couchbase-instance\n 5 | password: password\n 6 | scope: inventory\n 7 | ",
},
{
desc: "missing required field",
in: `
sources:
my-couchbase-instance:
kind: couchbase
username: Administrator
password: password
bucket: travel-sample
scope: inventory
kind: sources
name: my-couchbase-instance
type: couchbase
username: Administrator
password: password
bucket: travel-sample
scope: inventory
`,
err: "unable to parse source \"my-couchbase-instance\" as \"couchbase\": Key: 'Config.ConnectionString' Error:Field validation for 'ConnectionString' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-couchbase-instance\" as \"couchbase\": Key: 'Config.ConnectionString' Error:Field validation for 'ConnectionString' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
"google.golang.org/api/option"
)
const SourceKind string = "dataplex"
const SourceType string = "dataplex"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -51,13 +51,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Dataplex configs
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Project string `yaml:"project" validate:"required"`
}
func (r Config) SourceConfigKind() string {
// Returns Dataplex source kind
return SourceKind
func (r Config) SourceConfigType() string {
// Returns Dataplex source type
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -81,9 +81,9 @@ type Source struct {
Client *dataplexapi.CatalogClient
}
func (s *Source) SourceKind() string {
// Returns Dataplex source kind
return SourceKind
func (s *Source) SourceType() string {
// Returns Dataplex source type
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -104,7 +104,7 @@ func initDataplexConnection(
name string,
project string,
) (*dataplexapi.CatalogClient, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
cred, err := google.FindDefaultCredentials(ctx)

View File

@@ -15,11 +15,12 @@
package dataplex_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/dataplex"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,15 +34,15 @@ func TestParseFromYamlDataplex(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-instance:
kind: dataplex
project: my-project
kind: sources
name: my-instance
type: dataplex
project: my-project
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-instance": dataplex.Config{
Name: "my-instance",
Kind: dataplex.SourceKind,
Type: dataplex.SourceType,
Project: "my-project",
},
},
@@ -49,16 +50,12 @@ func TestParseFromYamlDataplex(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -74,31 +71,27 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-instance:
kind: dataplex
project: my-project
foo: bar
kind: sources
name: my-instance
type: dataplex
project: my-project
foo: bar
`,
err: "unable to parse source \"my-instance\" as \"dataplex\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: dataplex\n 3 | project: my-project",
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"dataplex\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | name: my-instance\n 3 | project: my-project\n 4 | type: dataplex",
},
{
desc: "missing required field",
in: `
sources:
my-instance:
kind: dataplex
kind: sources
name: my-instance
type: dataplex
`,
err: "unable to parse source \"my-instance\" as \"dataplex\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-instance\" as \"dataplex\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -30,14 +30,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "dgraph"
const SourceType string = "dgraph"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -67,7 +67,7 @@ type DgraphClient struct {
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
DgraphUrl string `yaml:"dgraphUrl" validate:"required"`
User string `yaml:"user"`
Password string `yaml:"password"`
@@ -75,8 +75,8 @@ type Config struct {
ApiKey string `yaml:"apiKey"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -103,8 +103,8 @@ type Source struct {
Client *DgraphClient `yaml:"client"`
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -139,7 +139,7 @@ func (s *Source) RunSQL(statement string, params parameters.ParamValues, isQuery
func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (*DgraphClient, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, r.Name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, r.Name)
defer span.End()
if r.DgraphUrl == "" {

View File

@@ -15,11 +15,12 @@
package dgraph_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/dgraph"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,19 +34,19 @@ func TestParseFromYamlDgraph(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-dgraph-instance:
kind: dgraph
dgraphUrl: https://localhost:8080
apiKey: abc123
password: pass@123
namespace: 0
user: user123
kind: sources
name: my-dgraph-instance
type: dgraph
dgraphUrl: https://localhost:8080
apiKey: abc123
password: pass@123
namespace: 0
user: user123
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-dgraph-instance": dgraph.Config{
Name: "my-dgraph-instance",
Kind: dgraph.SourceKind,
Type: dgraph.SourceType,
DgraphUrl: "https://localhost:8080",
ApiKey: "abc123",
Password: "pass@123",
@@ -57,15 +58,15 @@ func TestParseFromYamlDgraph(t *testing.T) {
{
desc: "basic example minimal field",
in: `
sources:
my-dgraph-instance:
kind: dgraph
dgraphUrl: https://localhost:8080
kind: sources
name: my-dgraph-instance
type: dgraph
dgraphUrl: https://localhost:8080
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-dgraph-instance": dgraph.Config{
Name: "my-dgraph-instance",
Kind: dgraph.SourceKind,
Type: dgraph.SourceType,
DgraphUrl: "https://localhost:8080",
},
},
@@ -74,16 +75,12 @@ func TestParseFromYamlDgraph(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Sources); diff != "" {
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
@@ -100,31 +97,27 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-dgraph-instance:
kind: dgraph
dgraphUrl: https://localhost:8080
foo: bar
kind: sources
name: my-dgraph-instance
type: dgraph
dgraphUrl: https://localhost:8080
foo: bar
`,
err: "unable to parse source \"my-dgraph-instance\" as \"dgraph\": [2:1] unknown field \"foo\"\n 1 | dgraphUrl: https://localhost:8080\n> 2 | foo: bar\n ^\n 3 | kind: dgraph",
err: "error unmarshaling sources: unable to parse source \"my-dgraph-instance\" as \"dgraph\": [2:1] unknown field \"foo\"\n 1 | dgraphUrl: https://localhost:8080\n> 2 | foo: bar\n ^\n 3 | name: my-dgraph-instance\n 4 | type: dgraph",
},
{
desc: "missing required field",
in: `
sources:
my-dgraph-instance:
kind: dgraph
kind: sources
name: my-dgraph-instance
type: dgraph
`,
err: "unable to parse source \"my-dgraph-instance\" as \"dgraph\": Key: 'Config.DgraphUrl' Error:Field validation for 'DgraphUrl' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-dgraph-instance\" as \"dgraph\": Key: 'Config.DgraphUrl' Error:Field validation for 'DgraphUrl' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -30,14 +30,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "elasticsearch"
const SourceType string = "elasticsearch"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -51,15 +51,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Addresses []string `yaml:"addresses" validate:"required"`
Username string `yaml:"username"`
Password string `yaml:"password"`
APIKey string `yaml:"apikey"`
}
func (c Config) SourceConfigKind() string {
return SourceKind
func (c Config) SourceConfigType() string {
return SourceType
}
type EsClient interface {
@@ -139,9 +139,9 @@ func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return s, nil
}
// SourceKind returns the kind string for this source.
func (s *Source) SourceKind() string {
return SourceKind
// SourceType returns the resourceType string for this source.
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,13 +15,15 @@
package elasticsearch_test
import (
"context"
"reflect"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/elasticsearch"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
func TestParseFromYamlElasticsearch(t *testing.T) {
@@ -33,17 +35,17 @@ func TestParseFromYamlElasticsearch(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-es-instance:
kind: elasticsearch
addresses:
- http://localhost:9200
apikey: somekey
`,
want: server.SourceConfigs{
kind: sources
name: my-es-instance
type: elasticsearch
addresses:
- http://localhost:9200
apikey: somekey
`,
want: map[string]sources.SourceConfig{
"my-es-instance": elasticsearch.Config{
Name: "my-es-instance",
Kind: elasticsearch.SourceKind,
Type: elasticsearch.SourceType,
Addresses: []string{"http://localhost:9200"},
APIKey: "somekey",
},
@@ -52,20 +54,50 @@ func TestParseFromYamlElasticsearch(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
err := yaml.Unmarshal([]byte(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("failed to parse yaml: %v", err)
}
if diff := cmp.Diff(tc.want, got.Sources); diff != "" {
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Errorf("unexpected config diff (-want +got):\n%s", diff)
}
})
}
}
func TestFailParseFromYaml(t *testing.T) {
tcs := []struct {
desc string
in string
err string
}{
{
desc: "extra field",
in: `
kind: sources
name: my-es-instance
type: elasticsearch
addresses:
- http://localhost:9200
foo: bar
`,
err: "error unmarshaling sources: unable to parse source \"my-es-instance\" as \"elasticsearch\": [3:1] unknown field \"foo\"\n 1 | addresses:\n 2 | - http://localhost:9200\n> 3 | foo: bar\n ^\n 4 | name: my-es-instance\n 5 | type: elasticsearch",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}
errStr := err.Error()
if errStr != tc.err {
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
}
})
}
}
func TestTool_esqlToMap(t1 *testing.T) {
tests := []struct {
name string

View File

@@ -27,13 +27,13 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
)
const SourceKind string = "firebird"
const SourceType string = "firebird"
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -47,7 +47,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -55,8 +55,8 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -84,8 +84,8 @@ type Source struct {
Db *sql.DB
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -144,7 +144,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
}
func initFirebirdConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*sql.DB, error) {
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
_, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
// urlExample := "user:password@host:port/path/to/database.fdb"

View File

@@ -15,11 +15,12 @@
package firebird_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/firebird"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,19 +34,19 @@ func TestParseFromYamlFirebird(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-fdb-instance:
kind: firebird
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-fdb-instance
type: firebird
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-fdb-instance": firebird.Config{
Name: "my-fdb-instance",
Kind: firebird.SourceKind,
Type: firebird.SourceType,
Host: "my-host",
Port: "my-port",
Database: "my_db",
@@ -57,16 +58,12 @@ func TestParseFromYamlFirebird(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -82,39 +79,35 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-fdb-instance:
kind: firebird
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
kind: sources
name: my-fdb-instance
type: firebird
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "unable to parse source \"my-fdb-instance\" as \"firebird\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: my-host\n 4 | kind: firebird\n 5 | password: my_pass\n 6 | ",
err: "error unmarshaling sources: unable to parse source \"my-fdb-instance\" as \"firebird\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: my-host\n 4 | name: my-fdb-instance\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
sources:
my-fdb-instance:
kind: firebird
host: my-host
port: my-port
database: my_db
user: my_user
kind: sources
name: my-fdb-instance
type: firebird
host: my-host
port: my-port
database: my_db
user: my_user
`,
err: "unable to parse source \"my-fdb-instance\" as \"firebird\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-fdb-instance\" as \"firebird\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -31,14 +31,14 @@ import (
"google.golang.org/genproto/googleapis/type/latlng"
)
const SourceKind string = "firestore"
const SourceType string = "firestore"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -53,14 +53,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Firestore configs
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Project string `yaml:"project" validate:"required"`
Database string `yaml:"database"` // Optional, defaults to "(default)"
}
func (r Config) SourceConfigKind() string {
// Returns Firestore source kind
return SourceKind
func (r Config) SourceConfigType() string {
// Returns Firestore source type
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -92,9 +92,9 @@ type Source struct {
RulesClient *firebaserules.Service
}
func (s *Source) SourceKind() string {
// Returns Firestore source kind
return SourceKind
func (s *Source) SourceType() string {
// Returns Firestore source type
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -594,7 +594,7 @@ func initFirestoreConnection(
project string,
database string,
) (*firestore.Client, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -15,12 +15,13 @@
package firestore_test
import (
"context"
"testing"
"time"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/firestore"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,15 +35,15 @@ func TestParseFromYamlFirestore(t *testing.T) {
{
desc: "basic example with default database",
in: `
sources:
my-firestore:
kind: firestore
project: my-project
kind: sources
name: my-firestore
type: firestore
project: my-project
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-firestore": firestore.Config{
Name: "my-firestore",
Kind: firestore.SourceKind,
Type: firestore.SourceType,
Project: "my-project",
Database: "",
},
@@ -51,16 +52,16 @@ func TestParseFromYamlFirestore(t *testing.T) {
{
desc: "with custom database",
in: `
sources:
my-firestore:
kind: firestore
project: my-project
database: my-database
kind: sources
name: my-firestore
type: firestore
project: my-project
database: my-database
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-firestore": firestore.Config{
Name: "my-firestore",
Kind: firestore.SourceKind,
Type: firestore.SourceType,
Project: "my-project",
Database: "my-database",
},
@@ -69,22 +70,18 @@ func TestParseFromYamlFirestore(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
}
func TestFailParseFromYamlFirestore(t *testing.T) {
func TestFailParseFromYaml(t *testing.T) {
tcs := []struct {
desc string
in string
@@ -93,32 +90,27 @@ func TestFailParseFromYamlFirestore(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-firestore:
kind: firestore
project: my-project
foo: bar
kind: sources
name: my-firestore
type: firestore
project: my-project
foo: bar
`,
err: "unable to parse source \"my-firestore\" as \"firestore\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: firestore\n 3 | project: my-project",
err: "error unmarshaling sources: unable to parse source \"my-firestore\" as \"firestore\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | name: my-firestore\n 3 | project: my-project\n 4 | type: firestore",
},
{
desc: "missing required field",
in: `
sources:
my-firestore:
kind: firestore
database: my-database
kind: sources
name: my-firestore
type: firestore
`,
err: "unable to parse source \"my-firestore\" as \"firestore\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-firestore\" as \"firestore\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "http"
const SourceType string = "http"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
BaseURL string `yaml:"baseUrl"`
Timeout string `yaml:"timeout"`
DefaultHeaders map[string]string `yaml:"headers"`
@@ -58,8 +58,8 @@ type Config struct {
DisableSslVerification bool `yaml:"disableSslVerification"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
// Initialize initializes an HTTP Source instance.
@@ -122,8 +122,8 @@ type Source struct {
client *http.Client
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,9 +15,9 @@
package http_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -34,15 +34,15 @@ func TestParseFromYamlHttp(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-http-instance:
kind: http
baseUrl: http://test_server/
kind: sources
name: my-http-instance
type: http
baseUrl: http://test_server/
`,
want: map[string]sources.SourceConfig{
"my-http-instance": http.Config{
Name: "my-http-instance",
Kind: http.SourceKind,
Type: http.SourceType,
BaseURL: "http://test_server/",
Timeout: "30s",
DisableSslVerification: false,
@@ -52,23 +52,23 @@ func TestParseFromYamlHttp(t *testing.T) {
{
desc: "advanced example",
in: `
sources:
my-http-instance:
kind: http
baseUrl: http://test_server/
timeout: 10s
headers:
Authorization: test_header
Custom-Header: custom
queryParams:
api-key: test_api_key
param: param-value
disableSslVerification: true
kind: sources
name: my-http-instance
type: http
baseUrl: http://test_server/
timeout: 10s
headers:
Authorization: test_header
Custom-Header: custom
queryParams:
api-key: test_api_key
param: param-value
disableSslVerification: true
`,
want: map[string]sources.SourceConfig{
"my-http-instance": http.Config{
Name: "my-http-instance",
Kind: http.SourceKind,
Type: http.SourceType,
BaseURL: "http://test_server/",
Timeout: "10s",
DefaultHeaders: map[string]string{"Authorization": "test_header", "Custom-Header": "custom"},
@@ -80,16 +80,12 @@ func TestParseFromYamlHttp(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -104,36 +100,32 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-http-instance:
kind: http
baseUrl: http://test_server/
timeout: 10s
headers:
Authorization: test_header
queryParams:
api-key: test_api_key
project: test-project
kind: sources
name: my-http-instance
type: http
baseUrl: http://test_server/
timeout: 10s
headers:
Authorization: test_header
queryParams:
api-key: test_api_key
project: test-project
`,
err: "unable to parse source \"my-http-instance\" as \"http\": [5:1] unknown field \"project\"\n 2 | headers:\n 3 | Authorization: test_header\n 4 | kind: http\n> 5 | project: test-project\n ^\n 6 | queryParams:\n 7 | api-key: test_api_key\n 8 | timeout: 10s",
err: "error unmarshaling sources: unable to parse source \"my-http-instance\" as \"http\": [5:1] unknown field \"project\"\n 2 | headers:\n 3 | Authorization: test_header\n 4 | name: my-http-instance\n> 5 | project: test-project\n ^\n 6 | queryParams:\n 7 | api-key: test_api_key\n 8 | timeout: 10s\n 9 | ",
},
{
desc: "missing required field",
in: `
sources:
my-http-instance:
baseUrl: http://test_server/
kind: sources
name: my-http-instance
baseUrl: http://test_server/
`,
err: "missing 'kind' field for source \"my-http-instance\"",
err: "error unmarshaling sources: missing 'type' field or it is not a string",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -33,14 +33,14 @@ import (
v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4"
)
const SourceKind string = "looker"
const SourceType string = "looker"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -64,7 +64,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
BaseURL string `yaml:"base_url" validate:"required"`
ClientId string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
@@ -79,8 +79,8 @@ type Config struct {
SessionLength int64 `yaml:"sessionLength"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
// Initialize initializes a Looker Source instance.
@@ -154,8 +154,8 @@ type Source struct {
AuthTokenHeaderName string
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {

View File

@@ -15,9 +15,9 @@
package looker_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -34,17 +34,17 @@ func TestParseFromYamlLooker(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-looker-instance:
kind: looker
base_url: http://example.looker.com/
client_id: jasdl;k;tjl
client_secret: sdakl;jgflkasdfkfg
kind: sources
name: my-looker-instance
type: looker
base_url: http://example.looker.com/
client_id: jasdl;k;tjl
client_secret: sdakl;jgflkasdfkfg
`,
want: map[string]sources.SourceConfig{
"my-looker-instance": looker.Config{
Name: "my-looker-instance",
Kind: looker.SourceKind,
Type: looker.SourceType,
BaseURL: "http://example.looker.com/",
ClientId: "jasdl;k;tjl",
ClientSecret: "sdakl;jgflkasdfkfg",
@@ -62,22 +62,18 @@ func TestParseFromYamlLooker(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
}
func TestFailParseFromYamlLooker(t *testing.T) {
func TestFailParseFromYaml(t *testing.T) {
tcs := []struct {
desc string
in string
@@ -86,34 +82,30 @@ func TestFailParseFromYamlLooker(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-looker-instance:
kind: looker
base_url: http://example.looker.com/
client_id: jasdl;k;tjl
client_secret: sdakl;jgflkasdfkfg
schema: test-schema
kind: sources
name: my-looker-instance
type: looker
base_url: http://example.looker.com/
client_id: jasdl;k;tjl
client_secret: sdakl;jgflkasdfkfg
schema: test-schema
`,
err: "unable to parse source \"my-looker-instance\" as \"looker\": [5:1] unknown field \"schema\"\n 2 | client_id: jasdl;k;tjl\n 3 | client_secret: sdakl;jgflkasdfkfg\n 4 | kind: looker\n> 5 | schema: test-schema\n ^\n",
err: "error unmarshaling sources: unable to parse source \"my-looker-instance\" as \"looker\": [5:1] unknown field \"schema\"\n 2 | client_id: jasdl;k;tjl\n 3 | client_secret: sdakl;jgflkasdfkfg\n 4 | name: my-looker-instance\n> 5 | schema: test-schema\n ^\n 6 | type: looker",
},
{
desc: "missing required field",
in: `
sources:
my-looker-instance:
kind: looker
client_id: jasdl;k;tjl
kind: sources
name: my-looker-instance
type: looker
client_id: jasdl;k;tjl
`,
err: "unable to parse source \"my-looker-instance\" as \"looker\": Key: 'Config.BaseURL' Error:Field validation for 'BaseURL' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-looker-instance\" as \"looker\": Key: 'Config.BaseURL' Error:Field validation for 'BaseURL' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -27,14 +27,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "mindsdb"
const SourceType string = "mindsdb"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -48,7 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -57,8 +57,8 @@ type Config struct {
QueryTimeout string `yaml:"queryTimeout"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -86,8 +86,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -159,7 +159,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initMindsDBConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
// Configure the driver to connect to the database

View File

@@ -15,11 +15,12 @@
package mindsdb_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/mindsdb"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,19 +34,19 @@ func TestParseFromYamlMindsDB(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-mindsdb-instance:
kind: mindsdb
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-mindsdb-instance
type: mindsdb
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-mindsdb-instance": mindsdb.Config{
Name: "my-mindsdb-instance",
Kind: mindsdb.SourceKind,
Type: mindsdb.SourceType,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -57,20 +58,20 @@ func TestParseFromYamlMindsDB(t *testing.T) {
{
desc: "with query timeout",
in: `
sources:
my-mindsdb-instance:
kind: mindsdb
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryTimeout: 45s
kind: sources
name: my-mindsdb-instance
type: mindsdb
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryTimeout: 45s
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-mindsdb-instance": mindsdb.Config{
Name: "my-mindsdb-instance",
Kind: mindsdb.SourceKind,
Type: mindsdb.SourceType,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -83,16 +84,12 @@ func TestParseFromYamlMindsDB(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -108,39 +105,35 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-mindsdb-instance:
kind: mindsdb
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
kind: sources
name: my-mindsdb-instance
type: mindsdb
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "unable to parse source \"my-mindsdb-instance\" as \"mindsdb\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: mindsdb\n 5 | password: my_pass\n 6 | ",
err: "error unmarshaling sources: unable to parse source \"my-mindsdb-instance\" as \"mindsdb\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-mindsdb-instance\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
sources:
my-mindsdb-instance:
kind: mindsdb
port: my-port
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-mindsdb-instance
type: mindsdb
port: my-port
database: my_db
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-mindsdb-instance\" as \"mindsdb\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-mindsdb-instance\" as \"mindsdb\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -29,14 +29,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "mongodb"
const SourceType string = "mongodb"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -50,12 +50,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Uri string `yaml:"uri" validate:"required"` // MongoDB Atlas connection URI
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -84,8 +84,8 @@ type Source struct {
Client *mongo.Client
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -293,7 +293,7 @@ func (s *Source) DeleteOne(ctx context.Context, filterString, database, collecti
func initMongoDBClient(ctx context.Context, tracer trace.Tracer, name, uri string) (*mongo.Client, error) {
// Start a tracing span
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -15,11 +15,12 @@
package mongodb_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/mongodb"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,15 +34,15 @@ func TestParseFromYamlMongoDB(t *testing.T) {
{
desc: "basic example",
in: `
sources:
mongo-db:
kind: "mongodb"
uri: "mongodb+srv://username:password@host/dbname"
kind: sources
name: mongo-db
type: "mongodb"
uri: "mongodb+srv://username:password@host/dbname"
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"mongo-db": mongodb.Config{
Name: "mongo-db",
Kind: mongodb.SourceKind,
Type: mongodb.SourceType,
Uri: "mongodb+srv://username:password@host/dbname",
},
},
@@ -49,16 +50,12 @@ func TestParseFromYamlMongoDB(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -74,31 +71,27 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
mongo-db:
kind: mongodb
uri: "mongodb+srv://username:password@host/dbname"
foo: bar
kind: sources
name: mongo-db
type: mongodb
uri: "mongodb+srv://username:password@host/dbname"
foo: bar
`,
err: "unable to parse source \"mongo-db\" as \"mongodb\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: mongodb\n 3 | uri: mongodb+srv://username:password@host/dbname",
err: "error unmarshaling sources: unable to parse source \"mongo-db\" as \"mongodb\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | name: mongo-db\n 3 | type: mongodb\n 4 | uri: mongodb+srv://username:password@host/dbname",
},
{
desc: "missing required field",
in: `
sources:
mongo-db:
kind: mongodb
kind: sources
name: mongo-db
type: mongodb
`,
err: "unable to parse source \"mongo-db\" as \"mongodb\": Key: 'Config.Uri' Error:Field validation for 'Uri' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"mongo-db\" as \"mongodb\": Key: 'Config.Uri' Error:Field validation for 'Uri' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -28,14 +28,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "mssql"
const SourceType string = "mssql"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -50,7 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// Cloud SQL MSSQL configs
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -59,9 +59,9 @@ type Config struct {
Encrypt string `yaml:"encrypt"`
}
func (r Config) SourceConfigKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceKind
func (r Config) SourceConfigType() string {
// Returns Cloud SQL MSSQL source type
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -91,9 +91,9 @@ type Source struct {
Db *sql.DB
}
func (s *Source) SourceKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceKind
func (s *Source) SourceType() string {
// Returns Cloud SQL MSSQL source type
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -156,7 +156,7 @@ func initMssqlConnection(
error,
) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
userAgent, err := util.UserAgentFromContext(ctx)

View File

@@ -15,11 +15,12 @@
package mssql_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/mssql"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,19 +34,19 @@ func TestParseFromYamlMssql(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-mssql-instance:
kind: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-mssql-instance
type: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-mssql-instance": mssql.Config{
Name: "my-mssql-instance",
Kind: mssql.SourceKind,
Type: mssql.SourceType,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -57,20 +58,20 @@ func TestParseFromYamlMssql(t *testing.T) {
{
desc: "with encrypt field",
in: `
sources:
my-mssql-instance:
kind: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
encrypt: strict
kind: sources
name: my-mssql-instance
type: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
encrypt: strict
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-mssql-instance": mssql.Config{
Name: "my-mssql-instance",
Kind: mssql.SourceKind,
Type: mssql.SourceType,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -83,16 +84,12 @@ func TestParseFromYamlMssql(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got)
}
})
}
@@ -107,39 +104,35 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-mssql-instance:
kind: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
kind: sources
name: my-mssql-instance
type: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "unable to parse source \"my-mssql-instance\" as \"mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: mssql\n 5 | password: my_pass\n 6 | ",
err: "error unmarshaling sources: unable to parse source \"my-mssql-instance\" as \"mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-mssql-instance\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
sources:
my-mssql-instance:
kind: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
kind: sources
name: my-mssql-instance
type: mssql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
`,
err: "unable to parse source \"my-mssql-instance\" as \"mssql\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-mssql-instance\" as \"mssql\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -30,14 +30,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "mysql"
const SourceType string = "mysql"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -51,7 +51,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -61,8 +61,8 @@ type Config struct {
QueryParams map[string]string `yaml:"queryParams"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -90,8 +90,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -158,7 +158,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string, queryParams map[string]string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
// Build query parameters via url.Values for deterministic order and proper escaping.

View File

@@ -19,12 +19,12 @@ import (
"strings"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"go.opentelemetry.io/otel/trace/noop"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/mysql"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -38,19 +38,19 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-mysql-instance:
kind: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-mysql-instance
type: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-mysql-instance": mysql.Config{
Name: "my-mysql-instance",
Kind: mysql.SourceKind,
Type: mysql.SourceType,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -62,20 +62,20 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "with query timeout",
in: `
sources:
my-mysql-instance:
kind: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryTimeout: 45s
kind: sources
name: my-mysql-instance
type: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryTimeout: 45s
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-mysql-instance": mysql.Config{
Name: "my-mysql-instance",
Kind: mysql.SourceKind,
Type: mysql.SourceType,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -88,22 +88,22 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
{
desc: "with query params",
in: `
sources:
my-mysql-instance:
kind: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryParams:
tls: preferred
charset: utf8mb4
kind: sources
name: my-mysql-instance
type: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryParams:
tls: preferred
charset: utf8mb4
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-mysql-instance": mysql.Config{
Name: "my-mysql-instance",
Kind: mysql.SourceKind,
Type: mysql.SourceType,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
@@ -120,15 +120,11 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Sources, cmpopts.EquateEmpty()); diff != "" {
if diff := cmp.Diff(tc.want, got, cmpopts.EquateEmpty()); diff != "" {
t.Fatalf("mismatch (-want +got):\n%s", diff)
}
})
@@ -145,55 +141,51 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-mysql-instance:
kind: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
kind: sources
name: my-mysql-instance
type: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "unknown field \"foo\"",
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-mysql-instance\n 5 | password: my_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
sources:
my-mysql-instance:
kind: mysql
port: my-port
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-mysql-instance
type: mysql
port: my-port
database: my_db
user: my_user
password: my_pass
`,
err: "Field validation for 'Host' failed",
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"mysql\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
},
{
desc: "invalid query params type",
in: `
sources:
my-mysql-instance:
kind: mysql
host: 0.0.0.0
port: 3306
database: my_db
user: my_user
password: my_pass
queryParams: not-a-map
kind: sources
name: my-mysql-instance
type: mysql
host: 0.0.0.0
port: 3306
database: my_db
user: my_user
password: my_pass
queryParams: not-a-map
`,
err: "string was used where mapping is expected",
err: "error unmarshaling sources: unable to parse source \"my-mysql-instance\" as \"mysql\": [6:14] string was used where mapping is expected\n 3 | name: my-mysql-instance\n 4 | password: my_pass\n 5 | port: 3306\n> 6 | queryParams: not-a-map\n ^\n 7 | type: mysql\n 8 | user: my_user",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}
@@ -211,7 +203,7 @@ func TestFailInitialization(t *testing.T) {
cfg := mysql.Config{
Name: "instance",
Kind: "mysql",
Type: "mysql",
Host: "localhost",
Port: "3306",
Database: "db",

View File

@@ -29,7 +29,7 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "neo4j"
const SourceType string = "neo4j"
var sourceClassifier *classifier.QueryClassifier = classifier.NewQueryClassifier()
@@ -37,8 +37,8 @@ var sourceClassifier *classifier.QueryClassifier = classifier.NewQueryClassifier
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -52,15 +52,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Uri string `yaml:"uri" validate:"required"`
User string `yaml:"user" validate:"required"`
Password string `yaml:"password" validate:"required"`
Database string `yaml:"database" validate:"required"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -91,8 +91,8 @@ type Source struct {
Driver neo4j.DriverWithContext
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -182,7 +182,7 @@ func addPlanChildren(p neo4j.Plan) []map[string]any {
func initNeo4jDriver(ctx context.Context, tracer trace.Tracer, uri, user, password, name string) (neo4j.DriverWithContext, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
auth := neo4j.BasicAuth(user, password, "")

View File

@@ -15,11 +15,12 @@
package neo4j_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/neo4j"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -33,18 +34,18 @@ func TestParseFromYamlNeo4j(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-neo4j-instance:
kind: neo4j
uri: neo4j+s://my-host:7687
database: my_db
user: my_user
password: my_pass
kind: sources
name: my-neo4j-instance
type: neo4j
uri: neo4j+s://my-host:7687
database: my_db
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-neo4j-instance": neo4j.Config{
Name: "my-neo4j-instance",
Kind: neo4j.SourceKind,
Type: neo4j.SourceType,
Uri: "neo4j+s://my-host:7687",
Database: "my_db",
User: "my_user",
@@ -55,16 +56,12 @@ func TestParseFromYamlNeo4j(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -80,37 +77,33 @@ func TestFailParseFromYaml(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-neo4j-instance:
kind: neo4j
uri: neo4j+s://my-host:7687
database: my_db
user: my_user
password: my_pass
foo: bar
kind: sources
name: my-neo4j-instance
type: neo4j
uri: neo4j+s://my-host:7687
database: my_db
user: my_user
password: my_pass
foo: bar
`,
err: "unable to parse source \"my-neo4j-instance\" as \"neo4j\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | kind: neo4j\n 4 | password: my_pass\n 5 | uri: neo4j+s://my-host:7687\n 6 | ",
err: "error unmarshaling sources: unable to parse source \"my-neo4j-instance\" as \"neo4j\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | name: my-neo4j-instance\n 4 | password: my_pass\n 5 | type: neo4j\n 6 | ",
},
{
desc: "missing required field",
in: `
sources:
my-neo4j-instance:
kind: neo4j
uri: neo4j+s://my-host:7687
database: my_db
user: my_user
kind: sources
name: my-neo4j-instance
type: neo4j
uri: neo4j+s://my-host:7687
database: my_db
user: my_user
`,
err: "unable to parse source \"my-neo4j-instance\" as \"neo4j\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-neo4j-instance\" as \"neo4j\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -27,14 +27,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "oceanbase"
const SourceType string = "oceanbase"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -48,7 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
@@ -57,8 +57,8 @@ type Config struct {
QueryTimeout string `yaml:"queryTimeout"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -86,8 +86,8 @@ type Source struct {
Pool *sql.DB
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -153,7 +153,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
}
func initOceanBaseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
_, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
defer span.End()
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, pass, host, port, dbname)

View File

@@ -15,11 +15,12 @@
package oceanbase_test
import (
"context"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/oceanbase"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -34,19 +35,19 @@ func TestParseFromYamlOceanBase(t *testing.T) {
{
desc: "basic example",
in: `
sources:
my-oceanbase-instance:
kind: oceanbase
host: 0.0.0.0
port: 2881
database: ob_db
user: ob_user
password: ob_pass
kind: sources
name: my-oceanbase-instance
type: oceanbase
host: 0.0.0.0
port: 2881
database: ob_db
user: ob_user
password: ob_pass
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-oceanbase-instance": oceanbase.Config{
Name: "my-oceanbase-instance",
Kind: oceanbase.SourceKind,
Type: oceanbase.SourceType,
Host: "0.0.0.0",
Port: "2881",
Database: "ob_db",
@@ -58,20 +59,20 @@ func TestParseFromYamlOceanBase(t *testing.T) {
{
desc: "with query timeout",
in: `
sources:
my-oceanbase-instance:
kind: oceanbase
host: 0.0.0.0
port: 2881
database: ob_db
user: ob_user
password: ob_pass
queryTimeout: 30s
kind: sources
name: my-oceanbase-instance
type: oceanbase
host: 0.0.0.0
port: 2881
database: ob_db
user: ob_user
password: ob_pass
queryTimeout: 30s
`,
want: server.SourceConfigs{
want: map[string]sources.SourceConfig{
"my-oceanbase-instance": oceanbase.Config{
Name: "my-oceanbase-instance",
Kind: oceanbase.SourceKind,
Type: oceanbase.SourceType,
Host: "0.0.0.0",
Port: "2881",
Database: "ob_db",
@@ -84,16 +85,12 @@ func TestParseFromYamlOceanBase(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
}
})
}
@@ -109,39 +106,35 @@ func TestFailParseFromYamlOceanBase(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-oceanbase-instance:
kind: oceanbase
host: 0.0.0.0
port: 2881
database: ob_db
user: ob_user
password: ob_pass
foo: bar
kind: sources
name: my-oceanbase-instance
type: oceanbase
host: 0.0.0.0
port: 2881
database: ob_db
user: ob_user
password: ob_pass
foo: bar
`,
err: "unable to parse source \"my-oceanbase-instance\" as \"oceanbase\": [2:1] unknown field \"foo\"\n 1 | database: ob_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: oceanbase\n 5 | password: ob_pass\n 6 | ",
err: "error unmarshaling sources: unable to parse source \"my-oceanbase-instance\" as \"oceanbase\": [2:1] unknown field \"foo\"\n 1 | database: ob_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | name: my-oceanbase-instance\n 5 | password: ob_pass\n 6 | ",
},
{
desc: "missing required field",
in: `
sources:
my-oceanbase-instance:
kind: oceanbase
port: 2881
database: ob_db
user: ob_user
password: ob_pass
kind: sources
name: my-oceanbase-instance
type: oceanbase
port: 2881
database: ob_db
user: ob_user
password: ob_pass
`,
err: "unable to parse source \"my-oceanbase-instance\" as \"oceanbase\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
err: "error unmarshaling sources: unable to parse source \"my-oceanbase-instance\" as \"oceanbase\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

View File

@@ -18,14 +18,14 @@ import (
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "oracle"
const SourceType string = "oracle"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
if !sources.Register(SourceType, newConfig) {
panic(fmt.Sprintf("source type %q already registered", SourceType))
}
}
@@ -45,7 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
ConnectionString string `yaml:"connectionString,omitempty"`
TnsAlias string `yaml:"tnsAlias,omitempty"`
TnsAdmin string `yaml:"tnsAdmin,omitempty"`
@@ -95,8 +95,8 @@ func (c Config) validate() error {
return nil
}
func (r Config) SourceConfigKind() string {
return SourceKind
func (r Config) SourceConfigType() string {
return SourceType
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
@@ -124,8 +124,8 @@ type Source struct {
DB *sql.DB
}
func (s *Source) SourceKind() string {
return SourceKind
func (s *Source) SourceType() string {
return SourceType
}
func (s *Source) ToConfig() sources.SourceConfig {
@@ -239,7 +239,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an
func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Config) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, config.Name)
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, config.Name)
defer span.End()
logger, err := util.LoggerFromContext(ctx)

View File

@@ -3,12 +3,13 @@
package oracle_test
import (
"context"
"strings"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
@@ -22,18 +23,18 @@ func TestParseFromYamlOracle(t *testing.T) {
{
desc: "connection string and useOCI=true",
in: `
sources:
my-oracle-cs:
kind: oracle
connectionString: "my-host:1521/XEPDB1"
user: my_user
password: my_pass
useOCI: true
`,
want: server.SourceConfigs{
kind: sources
name: my-oracle-cs
type: oracle
connectionString: "my-host:1521/XEPDB1"
user: my_user
password: my_pass
useOCI: true
`,
want: map[string]sources.SourceConfig{
"my-oracle-cs": oracle.Config{
Name: "my-oracle-cs",
Kind: oracle.SourceKind,
Type: oracle.SourceType,
ConnectionString: "my-host:1521/XEPDB1",
User: "my_user",
Password: "my_pass",
@@ -44,19 +45,19 @@ func TestParseFromYamlOracle(t *testing.T) {
{
desc: "host/port/serviceName and default useOCI=false",
in: `
sources:
my-oracle-host:
kind: oracle
host: my-host
port: 1521
serviceName: ORCLPDB
user: my_user
password: my_pass
`,
want: server.SourceConfigs{
kind: sources
name: my-oracle-host
type: oracle
host: my-host
port: 1521
serviceName: ORCLPDB
user: my_user
password: my_pass
`,
want: map[string]sources.SourceConfig{
"my-oracle-host": oracle.Config{
Name: "my-oracle-host",
Kind: oracle.SourceKind,
Type: oracle.SourceType,
Host: "my-host",
Port: 1521,
ServiceName: "ORCLPDB",
@@ -69,19 +70,19 @@ func TestParseFromYamlOracle(t *testing.T) {
{
desc: "tnsAlias and TnsAdmin specified with explicit useOCI=true",
in: `
sources:
my-oracle-tns-oci:
kind: oracle
tnsAlias: FINANCE_DB
tnsAdmin: /opt/oracle/network/admin
user: my_user
password: my_pass
useOCI: true
`,
want: server.SourceConfigs{
kind: sources
name: my-oracle-tns-oci
type: oracle
tnsAlias: FINANCE_DB
tnsAdmin: /opt/oracle/network/admin
user: my_user
password: my_pass
useOCI: true
`,
want: map[string]sources.SourceConfig{
"my-oracle-tns-oci": oracle.Config{
Name: "my-oracle-tns-oci",
Kind: oracle.SourceKind,
Type: oracle.SourceType,
TnsAlias: "FINANCE_DB",
TnsAdmin: "/opt/oracle/network/admin",
User: "my_user",
@@ -93,22 +94,18 @@ func TestParseFromYamlOracle(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse:\nwant: %v\ngot: %v\ndiff: %s", tc.want, got.Sources, cmp.Diff(tc.want, got.Sources))
if !cmp.Equal(tc.want, got) {
t.Fatalf("incorrect parse:\nwant: %v\ngot: %v\ndiff: %s", tc.want, got, cmp.Diff(tc.want, got))
}
})
}
}
func TestFailParseFromYamlOracle(t *testing.T) {
func TestFailParseFromYaml(t *testing.T) {
tcs := []struct {
desc string
in string
@@ -117,76 +114,72 @@ func TestFailParseFromYamlOracle(t *testing.T) {
{
desc: "extra field",
in: `
sources:
my-oracle-instance:
kind: oracle
host: my-host
serviceName: ORCL
user: my_user
password: my_pass
extraField: value
kind: sources
name: my-oracle-instance
type: oracle
host: my-host
serviceName: ORCL
user: my_user
password: my_pass
extraField: value
`,
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": [1:1] unknown field \"extraField\"\n> 1 | extraField: value\n ^\n 2 | host: my-host\n 3 | kind: oracle\n 4 | password: my_pass\n 5 | ",
err: "error unmarshaling sources: unable to parse source \"my-oracle-instance\" as \"oracle\": [1:1] unknown field \"extraField\"\n> 1 | extraField: value\n ^\n 2 | host: my-host\n 3 | name: my-oracle-instance\n 4 | password: my_pass\n 5 | ",
},
{
desc: "missing required password field",
in: `
sources:
my-oracle-instance:
kind: oracle
host: my-host
serviceName: ORCL
user: my_user
`,
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
kind: sources
name: my-oracle-instance
type: oracle
host: my-host
serviceName: ORCL
user: my_user
`,
err: "error unmarshaling sources: unable to parse source \"my-oracle-instance\" as \"oracle\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
},
{
desc: "missing connection method fields (validate fails)",
in: `
sources:
my-oracle-instance:
kind: oracle
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'",
kind: sources
name: my-oracle-instance
type: oracle
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'",
},
{
desc: "multiple connection methods provided (validate fails)",
in: `
sources:
my-oracle-instance:
kind: oracle
host: my-host
serviceName: ORCL
connectionString: "my-host:1521/XEPDB1"
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'",
kind: sources
name: my-oracle-instance
type: oracle
host: my-host
serviceName: ORCL
connectionString: "my-host:1521/XEPDB1"
user: my_user
password: my_pass
`,
err: "error unmarshaling sources: unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'",
},
{
desc: "fail on tnsAdmin with useOCI=false",
in: `
sources:
my-oracle-fail:
kind: oracle
tnsAlias: FINANCE_DB
tnsAdmin: /opt/oracle/network/admin
user: my_user
password: my_pass
useOCI: false
kind: sources
name: my-oracle-fail
type: oracle
tnsAlias: FINANCE_DB
tnsAdmin: /opt/oracle/network/admin
user: my_user
password: my_pass
useOCI: false
`,
err: "unable to parse source \"my-oracle-fail\" as \"oracle\": invalid Oracle configuration: `tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead",
err: "error unmarshaling sources: unable to parse source \"my-oracle-fail\" as \"oracle\": invalid Oracle configuration: `tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
}

Some files were not shown because too many files have changed in this diff Show More