mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-22 05:48:08 -05:00
Compare commits
3 Commits
config-upd
...
config-unm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c29355ff82 | ||
|
|
70f5550910 | ||
|
|
348c9fde08 |
135
cmd/root.go
135
cmd/root.go
@@ -15,6 +15,7 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
@@ -394,7 +395,6 @@ func NewCommand(opts ...Option) *Command {
|
||||
|
||||
type ToolsFile struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
|
||||
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
|
||||
EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
@@ -425,6 +425,106 @@ func parseEnv(input string) (string, error) {
|
||||
return output, err
|
||||
}
|
||||
|
||||
func convertToolsFile(ctx context.Context, raw []byte) ([]byte, error) {
|
||||
var input yaml.MapSlice
|
||||
decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap())
|
||||
if err := decoder.Decode(&input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert raw MapSlice to a helper map for quick lookup
|
||||
// while keeping the values as MapSlices to preserve internal order
|
||||
resourceOrder := []string{}
|
||||
lookup := make(map[string]yaml.MapSlice)
|
||||
for _, item := range input {
|
||||
key, ok := item.Key.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected non-string key in input: %v", item.Key)
|
||||
}
|
||||
if slice, ok := item.Value.(yaml.MapSlice); ok {
|
||||
// convert authSources to authServices
|
||||
if key == "authSources" {
|
||||
key = "authServices"
|
||||
}
|
||||
// works even if lookup[key] is nil
|
||||
lookup[key] = append(lookup[key], slice...)
|
||||
// preserving the resource's order of original toolsFile
|
||||
if !slices.Contains(resourceOrder, key) {
|
||||
resourceOrder = append(resourceOrder, key)
|
||||
}
|
||||
} else {
|
||||
// toolsfile is already v2
|
||||
if key == "kind" {
|
||||
return raw, nil
|
||||
}
|
||||
return nil, fmt.Errorf("'%s' is not a map", key)
|
||||
}
|
||||
}
|
||||
// convert to tools file v2
|
||||
var buf bytes.Buffer
|
||||
encoder := yaml.NewEncoder(&buf)
|
||||
for _, kind := range resourceOrder {
|
||||
data, exists := lookup[kind]
|
||||
if !exists {
|
||||
// if this is skipped for all keys, the tools file is in v2
|
||||
continue
|
||||
}
|
||||
// Transform each entry
|
||||
for _, entry := range data {
|
||||
entryName, ok := entry.Key.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected non-string key for entry in '%s': %v", kind, entry.Key)
|
||||
}
|
||||
entryBody := ProcessValue(entry.Value, kind == "toolsets")
|
||||
|
||||
transformed := yaml.MapSlice{
|
||||
{Key: "kind", Value: kind},
|
||||
{Key: "name", Value: entryName},
|
||||
}
|
||||
|
||||
// Merge the transformed body into our result
|
||||
if bodySlice, ok := entryBody.(yaml.MapSlice); ok {
|
||||
transformed = append(transformed, bodySlice...)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unable to convert entryBody to MapSlice")
|
||||
}
|
||||
|
||||
if err := encoder.Encode(transformed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type'
|
||||
func ProcessValue(v any, isToolset bool) any {
|
||||
switch val := v.(type) {
|
||||
case yaml.MapSlice:
|
||||
for i := range val {
|
||||
// Perform renaming
|
||||
if val[i].Key == "kind" {
|
||||
val[i].Key = "type"
|
||||
}
|
||||
// Recursive call for nested values (e.g., nested objects or lists)
|
||||
val[i].Value = ProcessValue(val[i].Value, false)
|
||||
}
|
||||
return val
|
||||
case []any:
|
||||
// Process lists: If it's a toolset top-level list, wrap it.
|
||||
if isToolset {
|
||||
return yaml.MapSlice{{Key: "tools", Value: val}}
|
||||
}
|
||||
// Otherwise, recurse into list items (to catch nested objects)
|
||||
for i := range val {
|
||||
val[i] = ProcessValue(val[i], false)
|
||||
}
|
||||
return val
|
||||
default:
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
// parseToolsFile parses the provided yaml into appropriate configs.
|
||||
func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
|
||||
var toolsFile ToolsFile
|
||||
@@ -435,8 +535,13 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
|
||||
}
|
||||
raw = []byte(output)
|
||||
|
||||
raw, err = convertToolsFile(ctx, raw)
|
||||
if err != nil {
|
||||
return toolsFile, fmt.Errorf("error converting tools file: %s", err)
|
||||
}
|
||||
|
||||
// Parse contents
|
||||
err = yaml.UnmarshalContext(ctx, raw, &toolsFile, yaml.Strict())
|
||||
toolsFile.Sources, toolsFile.AuthServices, toolsFile.EmbeddingModels, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts, err = server.UnmarshalResourceConfig(ctx, raw)
|
||||
if err != nil {
|
||||
return toolsFile, err
|
||||
}
|
||||
@@ -468,18 +573,6 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge authSources (deprecated, but still support)
|
||||
for name, authSource := range file.AuthSources {
|
||||
if _, exists := merged.AuthSources[name]; exists {
|
||||
conflicts = append(conflicts, fmt.Sprintf("authSource '%s' (file #%d)", name, fileIndex+1))
|
||||
} else {
|
||||
if merged.AuthSources == nil {
|
||||
merged.AuthSources = make(server.AuthServiceConfigs)
|
||||
}
|
||||
merged.AuthSources[name] = authSource
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge authServices
|
||||
for name, authService := range file.AuthServices {
|
||||
if _, exists := merged.AuthServices[name]; exists {
|
||||
@@ -955,20 +1048,6 @@ func run(cmd *Command) error {
|
||||
cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets
|
||||
cmd.cfg.PromptConfigs = finalToolsFile.Prompts
|
||||
|
||||
authSourceConfigs := finalToolsFile.AuthSources
|
||||
if authSourceConfigs != nil {
|
||||
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
|
||||
|
||||
for k, v := range authSourceConfigs {
|
||||
if _, exists := cmd.cfg.AuthServiceConfigs[k]; exists {
|
||||
errMsg := fmt.Errorf("resource conflict detected: authSource '%s' has the same name as an existing authService. Please rename your authSource", k)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
cmd.cfg.AuthServiceConfigs[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err)
|
||||
|
||||
425
cmd/root_test.go
425
cmd/root_test.go
@@ -23,12 +23,14 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
||||
@@ -494,6 +496,309 @@ func TestDefaultLogLevel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToolsFile(t *testing.T) {
|
||||
ctx, cancelCtx := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancelCtx()
|
||||
pr, pw := io.Pipe()
|
||||
defer pw.Close()
|
||||
defer pr.Close()
|
||||
|
||||
logger, err := log.NewStdLogger(pw, pw, "DEBUG")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup logger %s", err)
|
||||
}
|
||||
ctx = util.WithLogger(ctx, logger)
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want string
|
||||
isErr bool
|
||||
errStr string
|
||||
}{
|
||||
{
|
||||
desc: "basic convert",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
authServices:
|
||||
my-google-auth:
|
||||
kind: google
|
||||
clientId: testing-id
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
toolsets:
|
||||
example_toolset:
|
||||
- example_tool
|
||||
prompts:
|
||||
code_review:
|
||||
description: ask llm to analyze code quality
|
||||
messages:
|
||||
- content: "please review the following code for quality: {{.code}}"
|
||||
arguments:
|
||||
- name: code
|
||||
description: the code to review
|
||||
embeddingModels:
|
||||
gemini-model:
|
||||
kind: gemini
|
||||
model: gemini-embedding-001
|
||||
apiKey: some-key
|
||||
dimension: 768`,
|
||||
want: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: authServices
|
||||
name: my-google-auth
|
||||
type: google
|
||||
clientId: testing-id
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool
|
||||
---
|
||||
kind: prompts
|
||||
name: code_review
|
||||
description: ask llm to analyze code quality
|
||||
messages:
|
||||
- content: "please review the following code for quality: {{.code}}"
|
||||
arguments:
|
||||
- name: code
|
||||
description: the code to review
|
||||
---
|
||||
kind: embeddingModels
|
||||
name: gemini-model
|
||||
type: gemini
|
||||
model: gemini-embedding-001
|
||||
apiKey: some-key
|
||||
dimension: 768`,
|
||||
},
|
||||
{
|
||||
desc: "preserve resource order with grouping",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
authServices:
|
||||
my-google-auth:
|
||||
kind: google
|
||||
clientId: testing-id
|
||||
toolsets:
|
||||
example_toolset:
|
||||
- example_tool
|
||||
authSources:
|
||||
my-google-auth:
|
||||
kind: google
|
||||
clientId: testing-id`,
|
||||
want: `
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: authServices
|
||||
name: my-google-auth
|
||||
type: google
|
||||
clientId: testing-id
|
||||
---
|
||||
kind: authServices
|
||||
name: my-google-auth
|
||||
type: google
|
||||
clientId: testing-id
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
},
|
||||
{
|
||||
desc: "no convertion needed",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
want: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
},
|
||||
{
|
||||
desc: "invalid source",
|
||||
in: `sources: invalid`,
|
||||
isErr: true,
|
||||
errStr: "'sources' is not a map",
|
||||
},
|
||||
{
|
||||
desc: "invalid toolset",
|
||||
in: `toolsets: invalid`,
|
||||
isErr: true,
|
||||
errStr: "'toolsets' is not a map",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
output, err := convertToolsFile(ctx, []byte(tc.in))
|
||||
if tc.isErr {
|
||||
if err == nil {
|
||||
t.Fatalf("missing error: %s", tc.errStr)
|
||||
}
|
||||
if err.Error() != tc.errStr {
|
||||
t.Fatalf("invalid error string: got %s, want %s", err, tc.errStr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
var docs1, docs2 []yaml.MapSlice
|
||||
if docs1, err = decodeToMapSlice(string(output)); err != nil {
|
||||
t.Fatalf("error decoding output: %s", err)
|
||||
}
|
||||
if docs2, err = decodeToMapSlice(tc.want); err != nil {
|
||||
t.Fatalf("Error decoding want: %s", err)
|
||||
}
|
||||
if !reflect.DeepEqual(docs1, docs2) {
|
||||
t.Fatalf("incorrect output: got %s, want %s", string(output), tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func decodeToMapSlice(data string) ([]yaml.MapSlice, error) {
|
||||
// ensures that the order is correct
|
||||
var docs []yaml.MapSlice
|
||||
decoder := yaml.NewDecoder(strings.NewReader(data))
|
||||
for {
|
||||
var doc yaml.MapSlice
|
||||
err := decoder.Decode(&doc)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
func TestParseToolFile(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
@@ -505,7 +810,7 @@ func TestParseToolFile(t *testing.T) {
|
||||
wantToolsFile ToolsFile
|
||||
}{
|
||||
{
|
||||
description: "basic example",
|
||||
description: "basic example tools file v1",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
@@ -568,7 +873,121 @@ func TestParseToolFile(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "with prompts example",
|
||||
description: "basic example tools file v2",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: authServices
|
||||
name: my-google-auth
|
||||
type: google
|
||||
clientId: testing-id
|
||||
---
|
||||
kind: embeddingModels
|
||||
name: gemini-model
|
||||
type: gemini
|
||||
model: gemini-embedding-001
|
||||
apiKey: some-key
|
||||
dimension: 768
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool
|
||||
---
|
||||
kind: prompts
|
||||
name: code_review
|
||||
description: ask llm to analyze code quality
|
||||
messages:
|
||||
- content: "please review the following code for quality: {{.code}}"
|
||||
arguments:
|
||||
- name: code
|
||||
description: the code to review
|
||||
`,
|
||||
wantToolsFile: ToolsFile{
|
||||
Sources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Type: cloudsqlpgsrc.SourceType,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
IPType: "public",
|
||||
Database: "my_db",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
},
|
||||
},
|
||||
AuthServices: server.AuthServiceConfigs{
|
||||
"my-google-auth": google.Config{
|
||||
Name: "my-google-auth",
|
||||
Type: google.AuthServiceType,
|
||||
ClientID: "testing-id",
|
||||
},
|
||||
},
|
||||
EmbeddingModels: server.EmbeddingModelConfigs{
|
||||
"gemini-model": gemini.Config{
|
||||
Name: "gemini-model",
|
||||
Type: gemini.EmbeddingModelType,
|
||||
Model: "gemini-embedding-001",
|
||||
ApiKey: "some-key",
|
||||
Dimension: 768,
|
||||
},
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Type: "postgres-sql",
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Parameters: []parameters.Parameter{
|
||||
parameters.NewStringParameter("country", "some description"),
|
||||
},
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
Toolsets: server.ToolsetConfigs{
|
||||
"example_toolset": tools.ToolsetConfig{
|
||||
Name: "example_toolset",
|
||||
ToolNames: []string{"example_tool"},
|
||||
},
|
||||
},
|
||||
Prompts: server.PromptConfigs{
|
||||
"code_review": custom.Config{
|
||||
Name: "code_review",
|
||||
Description: "ask llm to analyze code quality",
|
||||
Arguments: prompts.Arguments{
|
||||
{Parameter: parameters.NewStringParameter("code", "the code to review")},
|
||||
},
|
||||
Messages: []prompts.Message{
|
||||
{Role: "user", Content: "please review the following code for quality: {{.code}}"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "only prompts",
|
||||
in: `
|
||||
prompts:
|
||||
my-prompt:
|
||||
@@ -799,7 +1218,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
Password: "my_pass",
|
||||
},
|
||||
},
|
||||
AuthSources: server.AuthServiceConfigs{
|
||||
AuthServices: server.AuthServiceConfigs{
|
||||
"my-google-service": google.Config{
|
||||
Name: "my-google-service",
|
||||
Type: google.AuthServiceType,
|
||||
|
||||
0
cmd/test.db
Normal file
0
cmd/test.db
Normal file
@@ -14,8 +14,10 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
@@ -127,269 +129,211 @@ func (s *StringLevel) Type() string {
|
||||
// SourceConfigs is a type used to allow unmarshal of the data source config map
|
||||
type SourceConfigs map[string]sources.SourceConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &SourceConfigs{}
|
||||
|
||||
func (c *SourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(SourceConfigs)
|
||||
// Parse the 'kind' fields for each source
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
// Unmarshal to a general type that ensure it capture all fields
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
|
||||
}
|
||||
|
||||
kind, ok := v["kind"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'kind' field for source %q", name)
|
||||
}
|
||||
kindStr, ok := kind.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid 'kind' field for source %q (must be a string)", name)
|
||||
}
|
||||
|
||||
yamlDecoder, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating YAML decoder for source %q: %w", name, err)
|
||||
}
|
||||
|
||||
sourceConfig, err := sources.DecodeConfig(ctx, kindStr, name, yamlDecoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
(*c)[name] = sourceConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthServiceConfigs is a type used to allow unmarshal of the data authService config map
|
||||
type AuthServiceConfigs map[string]auth.AuthServiceConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &AuthServiceConfigs{}
|
||||
|
||||
func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(AuthServiceConfigs)
|
||||
// Parse the 'kind' fields for each authService
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
|
||||
}
|
||||
|
||||
kind, ok := v["kind"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'kind' field for %q", name)
|
||||
}
|
||||
|
||||
dec, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
switch kind {
|
||||
case google.AuthServiceType:
|
||||
actual := google.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of auth source", kind)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EmbeddingModelConfigs is a type used to allow unmarshal of the embedding model config map
|
||||
type EmbeddingModelConfigs map[string]embeddingmodels.EmbeddingModelConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &EmbeddingModelConfigs{}
|
||||
|
||||
func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(EmbeddingModelConfigs)
|
||||
// Parse the 'kind' fields for each embedding model
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
// Unmarshal to a general type that ensure it capture all fields
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal embedding model %q: %w", name, err)
|
||||
}
|
||||
|
||||
kind, ok := v["kind"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'kind' field for embedding model %q", name)
|
||||
}
|
||||
|
||||
dec, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
switch kind {
|
||||
case gemini.EmbeddingModelType:
|
||||
actual := gemini.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of auth source", kind)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToolConfigs is a type used to allow unmarshal of the tool configs
|
||||
type ToolConfigs map[string]tools.ToolConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &ToolConfigs{}
|
||||
|
||||
func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(ToolConfigs)
|
||||
// Parse the 'kind' fields for each source
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
|
||||
}
|
||||
|
||||
// `authRequired` and `useClientOAuth` cannot be specified together
|
||||
if v["authRequired"] != nil && v["useClientOAuth"] == true {
|
||||
return fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method")
|
||||
}
|
||||
|
||||
// Make `authRequired` an empty list instead of nil for Tool manifest
|
||||
if v["authRequired"] == nil {
|
||||
v["authRequired"] = []string{}
|
||||
}
|
||||
|
||||
kindVal, ok := v["kind"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'kind' field for tool %q", name)
|
||||
}
|
||||
kindStr, ok := kindVal.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid 'kind' field for tool %q (must be a string)", name)
|
||||
}
|
||||
|
||||
yamlDecoder, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating YAML decoder for tool %q: %w", name, err)
|
||||
}
|
||||
|
||||
toolCfg, err := tools.DecodeConfig(ctx, kindStr, name, yamlDecoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
(*c)[name] = toolCfg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToolsetConfigs is a type used to allow unmarshal of the toolset configs
|
||||
type ToolsetConfigs map[string]tools.ToolsetConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &ToolsetConfigs{}
|
||||
|
||||
func (c *ToolsetConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(ToolsetConfigs)
|
||||
|
||||
var raw map[string][]string
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, toolList := range raw {
|
||||
(*c)[name] = tools.ToolsetConfig{Name: name, ToolNames: toolList}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PromptConfigs is a type used to allow unmarshal of the prompt configs
|
||||
type PromptConfigs map[string]prompts.PromptConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &PromptConfigs{}
|
||||
|
||||
func (c *PromptConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(PromptConfigs)
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal prompt %q: %w", name, err)
|
||||
}
|
||||
|
||||
// Look for the 'kind' field. If it's not present, kindStr will be an
|
||||
// empty string, which prompts.DecodeConfig will correctly default to "custom".
|
||||
var kindStr string
|
||||
if kindVal, ok := v["kind"]; ok {
|
||||
var isString bool
|
||||
kindStr, isString = kindVal.(string)
|
||||
if !isString {
|
||||
return fmt.Errorf("invalid 'kind' field for prompt %q (must be a string)", name)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new, strict decoder for this specific prompt's data.
|
||||
yamlDecoder, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating YAML decoder for prompt %q: %w", name, err)
|
||||
}
|
||||
|
||||
// Use the central registry to decode the prompt based on its kind.
|
||||
promptCfg, err := prompts.DecodeConfig(ctx, kindStr, name, yamlDecoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
(*c)[name] = promptCfg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PromptsetConfigs is a type used to allow unmarshal of the PromptsetConfigs configs
|
||||
// PromptConfigs is a type used to allow unmarshal of the prompt configs
|
||||
type PromptsetConfigs map[string]prompts.PromptsetConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &PromptsetConfigs{}
|
||||
func UnmarshalResourceConfig(ctx context.Context, raw []byte) (SourceConfigs, AuthServiceConfigs, EmbeddingModelConfigs, ToolConfigs, ToolsetConfigs, PromptConfigs, error) {
|
||||
// prepare configs map
|
||||
sourceConfigs := make(map[string]sources.SourceConfig)
|
||||
authServiceConfigs := make(AuthServiceConfigs)
|
||||
embeddingModelConfigs := make(EmbeddingModelConfigs)
|
||||
toolConfigs := make(ToolConfigs)
|
||||
toolsetConfigs := make(ToolsetConfigs)
|
||||
promptConfigs := make(PromptConfigs)
|
||||
// promptset configs is not yet supported
|
||||
|
||||
func (c *PromptsetConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(PromptsetConfigs)
|
||||
decoder := yaml.NewDecoder(bytes.NewReader(raw))
|
||||
// for loop to unmarshal documents with the `---` separator
|
||||
for {
|
||||
var resource map[string]any
|
||||
if err := decoder.DecodeContext(ctx, &resource); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unable to decode YAML document: %w", err)
|
||||
}
|
||||
var kind, name string
|
||||
var ok bool
|
||||
if kind, ok = resource["kind"].(string); !ok {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("missing 'kind' field or it is not a string")
|
||||
}
|
||||
if name, ok = resource["name"].(string); !ok {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("missing 'name' field or it is not a string")
|
||||
}
|
||||
// remove 'kind' from map for strict unmarshaling
|
||||
delete(resource, "kind")
|
||||
|
||||
var raw map[string][]string
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
switch kind {
|
||||
case "sources":
|
||||
c, err := UnmarshalYAMLSourceConfig(ctx, name, resource)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
|
||||
}
|
||||
sourceConfigs[name] = c
|
||||
case "authServices":
|
||||
c, err := UnmarshalYAMLAuthServiceConfig(ctx, name, resource)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
|
||||
}
|
||||
authServiceConfigs[name] = c
|
||||
case "tools":
|
||||
c, err := UnmarshalYAMLToolConfig(ctx, name, resource)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
|
||||
}
|
||||
toolConfigs[name] = c
|
||||
case "toolsets":
|
||||
c, err := UnmarshalYAMLToolsetConfig(ctx, name, resource)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
|
||||
}
|
||||
toolsetConfigs[name] = c
|
||||
case "embeddingModels":
|
||||
c, err := UnmarshalYAMLEmbeddingModelConfig(ctx, name, resource)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
|
||||
}
|
||||
embeddingModelConfigs[name] = c
|
||||
case "prompts":
|
||||
c, err := UnmarshalYAMLPromptConfig(ctx, name, resource)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
|
||||
}
|
||||
promptConfigs[name] = c
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("invalid kind %s", kind)
|
||||
}
|
||||
}
|
||||
|
||||
for name, promptList := range raw {
|
||||
(*c)[name] = prompts.PromptsetConfig{Name: name, PromptNames: promptList}
|
||||
}
|
||||
return nil
|
||||
return sourceConfigs, authServiceConfigs, embeddingModelConfigs, toolConfigs, toolsetConfigs, promptConfigs, nil
|
||||
}
|
||||
|
||||
func UnmarshalYAMLSourceConfig(ctx context.Context, name string, r map[string]any) (sources.SourceConfig, error) {
|
||||
typeStr, ok := r["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'type' field or it is not a string")
|
||||
}
|
||||
dec, err := util.NewStrictDecoder(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
sourceConfig, err := sources.DecodeConfig(ctx, typeStr, name, dec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sourceConfig, nil
|
||||
}
|
||||
|
||||
func UnmarshalYAMLAuthServiceConfig(ctx context.Context, name string, r map[string]any) (auth.AuthServiceConfig, error) {
|
||||
typeStr, ok := r["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'type' field or it is not a string")
|
||||
}
|
||||
if typeStr != google.AuthServiceType {
|
||||
return nil, fmt.Errorf("%s is not a valid type of auth service", typeStr)
|
||||
}
|
||||
dec, err := util.NewStrictDecoder(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating decoder: %s", err)
|
||||
}
|
||||
actual := google.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %s: %w", name, err)
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
func UnmarshalYAMLEmbeddingModelConfig(ctx context.Context, name string, r map[string]any) (embeddingmodels.EmbeddingModelConfig, error) {
|
||||
typeStr, ok := r["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'type' field or it is not a string")
|
||||
}
|
||||
if typeStr != gemini.EmbeddingModelType {
|
||||
return nil, fmt.Errorf("%s is not a valid type of embedding model", typeStr)
|
||||
}
|
||||
dec, err := util.NewStrictDecoder(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating decoder: %s", err)
|
||||
}
|
||||
actual := gemini.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", name, err)
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
func UnmarshalYAMLToolConfig(ctx context.Context, name string, r map[string]any) (tools.ToolConfig, error) {
|
||||
typeStr, ok := r["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'type' field or it is not a string")
|
||||
}
|
||||
// `authRequired` and `useClientOAuth` cannot be specified together
|
||||
if r["authRequired"] != nil && r["useClientOAuth"] == true {
|
||||
return nil, fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method")
|
||||
}
|
||||
// Make `authRequired` an empty list instead of nil for Tool manifest
|
||||
if r["authRequired"] == nil {
|
||||
r["authRequired"] = []string{}
|
||||
}
|
||||
dec, err := util.NewStrictDecoder(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating decoder: %s", err)
|
||||
}
|
||||
toolCfg, err := tools.DecodeConfig(ctx, typeStr, name, dec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toolCfg, nil
|
||||
}
|
||||
|
||||
func UnmarshalYAMLToolsetConfig(ctx context.Context, name string, r map[string]any) (tools.ToolsetConfig, error) {
|
||||
var toolsetConfig tools.ToolsetConfig
|
||||
justTools := map[string]any{"tools": r["tools"]}
|
||||
dec, err := util.NewStrictDecoder(justTools)
|
||||
if err != nil {
|
||||
return toolsetConfig, fmt.Errorf("error creating decoder: %s", err)
|
||||
}
|
||||
var raw map[string][]string
|
||||
if err := dec.DecodeContext(ctx, &raw); err != nil {
|
||||
return toolsetConfig, fmt.Errorf("unable to unmarshal tools: %s", err)
|
||||
}
|
||||
return tools.ToolsetConfig{Name: name, ToolNames: raw["tools"]}, nil
|
||||
}
|
||||
|
||||
func UnmarshalYAMLPromptConfig(ctx context.Context, name string, r map[string]any) (prompts.PromptConfig, error) {
|
||||
// Look for the 'kind' field. If it's not present, kindStr will be an
|
||||
// empty string, which prompts.DecodeConfig will correctly default to "custom".
|
||||
var typeStr string
|
||||
if typeVal, ok := r["type"]; ok {
|
||||
var isString bool
|
||||
typeStr, isString = typeVal.(string)
|
||||
if !isString {
|
||||
return nil, fmt.Errorf("invalid 'type' field for prompt %q (must be a string)", name)
|
||||
}
|
||||
}
|
||||
dec, err := util.NewStrictDecoder(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating decoder: %s", err)
|
||||
}
|
||||
|
||||
// Use the central registry to decode the prompt based on its kind.
|
||||
promptCfg, err := prompts.DecodeConfig(ctx, typeStr, name, dec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return promptCfg, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user