Compare commits

...

1 Commits

Author SHA1 Message Date
Yuan Teoh
c29355ff82 chore: update unmarshal function for ToolsFile 2026-01-21 22:49:06 -08:00
3 changed files with 312 additions and 276 deletions

View File

@@ -395,7 +395,6 @@ func NewCommand(opts ...Option) *Command {
type ToolsFile struct {
Sources server.SourceConfigs `yaml:"sources"`
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"`
Tools server.ToolConfigs `yaml:"tools"`
@@ -536,8 +535,13 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
}
raw = []byte(output)
raw, err = convertToolsFile(ctx, raw)
if err != nil {
return toolsFile, fmt.Errorf("error converting tools file: %s", err)
}
// Parse contents
err = yaml.UnmarshalContext(ctx, raw, &toolsFile, yaml.Strict())
toolsFile.Sources, toolsFile.AuthServices, toolsFile.EmbeddingModels, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts, err = server.UnmarshalResourceConfig(ctx, raw)
if err != nil {
return toolsFile, err
}
@@ -569,18 +573,6 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
}
}
// Check for conflicts and merge authSources (deprecated, but still support)
for name, authSource := range file.AuthSources {
if _, exists := merged.AuthSources[name]; exists {
conflicts = append(conflicts, fmt.Sprintf("authSource '%s' (file #%d)", name, fileIndex+1))
} else {
if merged.AuthSources == nil {
merged.AuthSources = make(server.AuthServiceConfigs)
}
merged.AuthSources[name] = authSource
}
}
// Check for conflicts and merge authServices
for name, authService := range file.AuthServices {
if _, exists := merged.AuthServices[name]; exists {
@@ -1056,20 +1048,6 @@ func run(cmd *Command) error {
cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets
cmd.cfg.PromptConfigs = finalToolsFile.Prompts
authSourceConfigs := finalToolsFile.AuthSources
if authSourceConfigs != nil {
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
for k, v := range authSourceConfigs {
if _, exists := cmd.cfg.AuthServiceConfigs[k]; exists {
errMsg := fmt.Errorf("resource conflict detected: authSource '%s' has the same name as an existing authService. Please rename your authSource", k)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
cmd.cfg.AuthServiceConfigs[k] = v
}
}
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
if err != nil {
errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err)

View File

@@ -810,7 +810,7 @@ func TestParseToolFile(t *testing.T) {
wantToolsFile ToolsFile
}{
{
description: "basic example",
description: "basic example tools file v1",
in: `
sources:
my-pg-instance:
@@ -873,7 +873,121 @@ func TestParseToolFile(t *testing.T) {
},
},
{
description: "with prompts example",
description: "basic example tools file v2",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: authServices
name: my-google-auth
type: google
clientId: testing-id
---
kind: embeddingModels
name: gemini-model
type: gemini
model: gemini-embedding-001
apiKey: some-key
dimension: 768
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool
---
kind: prompts
name: code_review
description: ask llm to analyze code quality
messages:
- content: "please review the following code for quality: {{.code}}"
arguments:
- name: code
description: the code to review
`,
wantToolsFile: ToolsFile{
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Type: cloudsqlpgsrc.SourceType,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
IPType: "public",
Database: "my_db",
User: "my_user",
Password: "my_pass",
},
},
AuthServices: server.AuthServiceConfigs{
"my-google-auth": google.Config{
Name: "my-google-auth",
Type: google.AuthServiceType,
ClientID: "testing-id",
},
},
EmbeddingModels: server.EmbeddingModelConfigs{
"gemini-model": gemini.Config{
Name: "gemini-model",
Type: gemini.EmbeddingModelType,
Model: "gemini-embedding-001",
ApiKey: "some-key",
Dimension: 768,
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Type: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Parameters: []parameters.Parameter{
parameters.NewStringParameter("country", "some description"),
},
AuthRequired: []string{},
},
},
Toolsets: server.ToolsetConfigs{
"example_toolset": tools.ToolsetConfig{
Name: "example_toolset",
ToolNames: []string{"example_tool"},
},
},
Prompts: server.PromptConfigs{
"code_review": custom.Config{
Name: "code_review",
Description: "ask llm to analyze code quality",
Arguments: prompts.Arguments{
{Parameter: parameters.NewStringParameter("code", "the code to review")},
},
Messages: []prompts.Message{
{Role: "user", Content: "please review the following code for quality: {{.code}}"},
},
},
},
},
},
{
description: "only prompts",
in: `
prompts:
my-prompt:
@@ -1104,7 +1218,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
Password: "my_pass",
},
},
AuthSources: server.AuthServiceConfigs{
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Type: google.AuthServiceType,

View File

@@ -14,8 +14,10 @@
package server
import (
"bytes"
"context"
"fmt"
"io"
"strings"
yaml "github.com/goccy/go-yaml"
@@ -127,269 +129,211 @@ func (s *StringLevel) Type() string {
// SourceConfigs is a type used to allow unmarshal of the data source config map
type SourceConfigs map[string]sources.SourceConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &SourceConfigs{}
func (c *SourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(SourceConfigs)
// Parse the 'kind' fields for each source
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
// Unmarshal to a general type that ensure it capture all fields
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
}
kind, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for source %q", name)
}
kindStr, ok := kind.(string)
if !ok {
return fmt.Errorf("invalid 'kind' field for source %q (must be a string)", name)
}
yamlDecoder, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating YAML decoder for source %q: %w", name, err)
}
sourceConfig, err := sources.DecodeConfig(ctx, kindStr, name, yamlDecoder)
if err != nil {
return err
}
(*c)[name] = sourceConfig
}
return nil
}
// AuthServiceConfigs is a type used to allow unmarshal of the data authService config map
type AuthServiceConfigs map[string]auth.AuthServiceConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &AuthServiceConfigs{}
func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(AuthServiceConfigs)
// Parse the 'kind' fields for each authService
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
}
kind, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for %q", name)
}
dec, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating decoder: %w", err)
}
switch kind {
case google.AuthServiceType:
actual := google.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of auth source", kind)
}
}
return nil
}
// EmbeddingModelConfigs is a type used to allow unmarshal of the embedding model config map
type EmbeddingModelConfigs map[string]embeddingmodels.EmbeddingModelConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &EmbeddingModelConfigs{}
func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(EmbeddingModelConfigs)
// Parse the 'kind' fields for each embedding model
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
// Unmarshal to a general type that ensure it capture all fields
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal embedding model %q: %w", name, err)
}
kind, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for embedding model %q", name)
}
dec, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating decoder: %w", err)
}
switch kind {
case gemini.EmbeddingModelType:
actual := gemini.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of auth source", kind)
}
}
return nil
}
// ToolConfigs is a type used to allow unmarshal of the tool configs
type ToolConfigs map[string]tools.ToolConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &ToolConfigs{}
func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(ToolConfigs)
// Parse the 'kind' fields for each source
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
}
// `authRequired` and `useClientOAuth` cannot be specified together
if v["authRequired"] != nil && v["useClientOAuth"] == true {
return fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method")
}
// Make `authRequired` an empty list instead of nil for Tool manifest
if v["authRequired"] == nil {
v["authRequired"] = []string{}
}
kindVal, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for tool %q", name)
}
kindStr, ok := kindVal.(string)
if !ok {
return fmt.Errorf("invalid 'kind' field for tool %q (must be a string)", name)
}
yamlDecoder, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating YAML decoder for tool %q: %w", name, err)
}
toolCfg, err := tools.DecodeConfig(ctx, kindStr, name, yamlDecoder)
if err != nil {
return err
}
(*c)[name] = toolCfg
}
return nil
}
// ToolsetConfigs is a type used to allow unmarshal of the toolset configs
type ToolsetConfigs map[string]tools.ToolsetConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &ToolsetConfigs{}
func (c *ToolsetConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(ToolsetConfigs)
var raw map[string][]string
if err := unmarshal(&raw); err != nil {
return err
}
for name, toolList := range raw {
(*c)[name] = tools.ToolsetConfig{Name: name, ToolNames: toolList}
}
return nil
}
// PromptConfigs is a type used to allow unmarshal of the prompt configs
type PromptConfigs map[string]prompts.PromptConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &PromptConfigs{}
func (c *PromptConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(PromptConfigs)
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal prompt %q: %w", name, err)
}
// Look for the 'kind' field. If it's not present, kindStr will be an
// empty string, which prompts.DecodeConfig will correctly default to "custom".
var kindStr string
if kindVal, ok := v["kind"]; ok {
var isString bool
kindStr, isString = kindVal.(string)
if !isString {
return fmt.Errorf("invalid 'kind' field for prompt %q (must be a string)", name)
}
}
// Create a new, strict decoder for this specific prompt's data.
yamlDecoder, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating YAML decoder for prompt %q: %w", name, err)
}
// Use the central registry to decode the prompt based on its kind.
promptCfg, err := prompts.DecodeConfig(ctx, kindStr, name, yamlDecoder)
if err != nil {
return err
}
(*c)[name] = promptCfg
}
return nil
}
// PromptsetConfigs is a type used to allow unmarshal of the PromptsetConfigs configs
// PromptConfigs is a type used to allow unmarshal of the prompt configs
type PromptsetConfigs map[string]prompts.PromptsetConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &PromptsetConfigs{}
func UnmarshalResourceConfig(ctx context.Context, raw []byte) (SourceConfigs, AuthServiceConfigs, EmbeddingModelConfigs, ToolConfigs, ToolsetConfigs, PromptConfigs, error) {
// prepare configs map
sourceConfigs := make(map[string]sources.SourceConfig)
authServiceConfigs := make(AuthServiceConfigs)
embeddingModelConfigs := make(EmbeddingModelConfigs)
toolConfigs := make(ToolConfigs)
toolsetConfigs := make(ToolsetConfigs)
promptConfigs := make(PromptConfigs)
// promptset configs is not yet supported
func (c *PromptsetConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(PromptsetConfigs)
decoder := yaml.NewDecoder(bytes.NewReader(raw))
// for loop to unmarshal documents with the `---` separator
for {
var resource map[string]any
if err := decoder.DecodeContext(ctx, &resource); err != nil {
if err == io.EOF {
break
}
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unable to decode YAML document: %w", err)
}
var kind, name string
var ok bool
if kind, ok = resource["kind"].(string); !ok {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("missing 'kind' field or it is not a string")
}
if name, ok = resource["name"].(string); !ok {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("missing 'name' field or it is not a string")
}
// remove 'kind' from map for strict unmarshaling
delete(resource, "kind")
var raw map[string][]string
if err := unmarshal(&raw); err != nil {
return err
switch kind {
case "sources":
c, err := UnmarshalYAMLSourceConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
sourceConfigs[name] = c
case "authServices":
c, err := UnmarshalYAMLAuthServiceConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
authServiceConfigs[name] = c
case "tools":
c, err := UnmarshalYAMLToolConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
toolConfigs[name] = c
case "toolsets":
c, err := UnmarshalYAMLToolsetConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
toolsetConfigs[name] = c
case "embeddingModels":
c, err := UnmarshalYAMLEmbeddingModelConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
embeddingModelConfigs[name] = c
case "prompts":
c, err := UnmarshalYAMLPromptConfig(ctx, name, resource)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err)
}
promptConfigs[name] = c
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("invalid kind %s", kind)
}
}
for name, promptList := range raw {
(*c)[name] = prompts.PromptsetConfig{Name: name, PromptNames: promptList}
}
return nil
return sourceConfigs, authServiceConfigs, embeddingModelConfigs, toolConfigs, toolsetConfigs, promptConfigs, nil
}
func UnmarshalYAMLSourceConfig(ctx context.Context, name string, r map[string]any) (sources.SourceConfig, error) {
typeStr, ok := r["type"].(string)
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %w", err)
}
sourceConfig, err := sources.DecodeConfig(ctx, typeStr, name, dec)
if err != nil {
return nil, err
}
return sourceConfig, nil
}
func UnmarshalYAMLAuthServiceConfig(ctx context.Context, name string, r map[string]any) (auth.AuthServiceConfig, error) {
typeStr, ok := r["type"].(string)
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}
if typeStr != google.AuthServiceType {
return nil, fmt.Errorf("%s is not a valid type of auth service", typeStr)
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %s", err)
}
actual := google.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return nil, fmt.Errorf("unable to parse as %s: %w", name, err)
}
return actual, nil
}
func UnmarshalYAMLEmbeddingModelConfig(ctx context.Context, name string, r map[string]any) (embeddingmodels.EmbeddingModelConfig, error) {
typeStr, ok := r["type"].(string)
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}
if typeStr != gemini.EmbeddingModelType {
return nil, fmt.Errorf("%s is not a valid type of embedding model", typeStr)
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %s", err)
}
actual := gemini.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", name, err)
}
return actual, nil
}
func UnmarshalYAMLToolConfig(ctx context.Context, name string, r map[string]any) (tools.ToolConfig, error) {
typeStr, ok := r["type"].(string)
if !ok {
return nil, fmt.Errorf("missing 'type' field or it is not a string")
}
// `authRequired` and `useClientOAuth` cannot be specified together
if r["authRequired"] != nil && r["useClientOAuth"] == true {
return nil, fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method")
}
// Make `authRequired` an empty list instead of nil for Tool manifest
if r["authRequired"] == nil {
r["authRequired"] = []string{}
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %s", err)
}
toolCfg, err := tools.DecodeConfig(ctx, typeStr, name, dec)
if err != nil {
return nil, err
}
return toolCfg, nil
}
func UnmarshalYAMLToolsetConfig(ctx context.Context, name string, r map[string]any) (tools.ToolsetConfig, error) {
var toolsetConfig tools.ToolsetConfig
justTools := map[string]any{"tools": r["tools"]}
dec, err := util.NewStrictDecoder(justTools)
if err != nil {
return toolsetConfig, fmt.Errorf("error creating decoder: %s", err)
}
var raw map[string][]string
if err := dec.DecodeContext(ctx, &raw); err != nil {
return toolsetConfig, fmt.Errorf("unable to unmarshal tools: %s", err)
}
return tools.ToolsetConfig{Name: name, ToolNames: raw["tools"]}, nil
}
func UnmarshalYAMLPromptConfig(ctx context.Context, name string, r map[string]any) (prompts.PromptConfig, error) {
// Look for the 'kind' field. If it's not present, kindStr will be an
// empty string, which prompts.DecodeConfig will correctly default to "custom".
var typeStr string
if typeVal, ok := r["type"]; ok {
var isString bool
typeStr, isString = typeVal.(string)
if !isString {
return nil, fmt.Errorf("invalid 'type' field for prompt %q (must be a string)", name)
}
}
dec, err := util.NewStrictDecoder(r)
if err != nil {
return nil, fmt.Errorf("error creating decoder: %s", err)
}
// Use the central registry to decode the prompt based on its kind.
promptCfg, err := prompts.DecodeConfig(ctx, typeStr, name, dec)
if err != nil {
return nil, err
}
return promptCfg, nil
}