From f57352fcaafd8d3eb7cb2b2fda4366b453773a99 Mon Sep 17 00:00:00 2001 From: Yuan Teoh <45984206+Yuan325@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:40:32 -0800 Subject: [PATCH] chore: update unmarshal function for ToolsFile (#2344) Update `parseToolsFile()` to use the file conversion function. Update unmarshalling function since we will be unmarshalling per-doc (separated by `---`). `AuthSources` will be converted to `AuthServices` during the file conversion stage. Hence, removing it from `ToolsFile` struct. ToolsFile v2 will not support `AuthSources`. Double checked that all docs do not reference to `AuthSources` anymore; new users is expected to be using `AuthServices` instead. This PR will not pass the unit test since the updates for resources's yaml tag will be in subsequent PRs. Breaking it down to keep review simpler. Related #817 Upcoming PR: * Update yaml tag --- cmd/root.go | 36 +-- cmd/root_test.go | 120 ++++++++- internal/server/config.go | 512 +++++++++++++++++--------------------- 3 files changed, 353 insertions(+), 315 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 2a3f21e919..ce32662cea 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -397,7 +397,6 @@ func NewCommand(opts ...Option) *Command { type ToolsFile struct { Sources server.SourceConfigs `yaml:"sources"` - AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility. AuthServices server.AuthServiceConfigs `yaml:"authServices"` EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"` Tools server.ToolConfigs `yaml:"tools"` @@ -455,7 +454,7 @@ func convertToolsFile(raw []byte) ([]byte, error) { // fields such as "tools" in toolsets might pass the first check but // fail to convert to MapSlice if slice, ok := item.Value.(yaml.MapSlice); ok { - // convert authSources to authServices + // Deprecated: convert authSources to authServices if key == "authSources" { key = "authServices" } @@ -561,8 +560,13 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) { } raw = []byte(output) + raw, err = convertToolsFile(raw) + if err != nil { + return toolsFile, fmt.Errorf("error converting tools file: %s", err) + } + // Parse contents - err = yaml.UnmarshalContext(ctx, raw, &toolsFile, yaml.Strict()) + toolsFile.Sources, toolsFile.AuthServices, toolsFile.EmbeddingModels, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts, err = server.UnmarshalResourceConfig(ctx, raw) if err != nil { return toolsFile, err } @@ -594,18 +598,6 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) { } } - // Check for conflicts and merge authSources (deprecated, but still support) - for name, authSource := range file.AuthSources { - if _, exists := merged.AuthSources[name]; exists { - conflicts = append(conflicts, fmt.Sprintf("authSource '%s' (file #%d)", name, fileIndex+1)) - } else { - if merged.AuthSources == nil { - merged.AuthSources = make(server.AuthServiceConfigs) - } - merged.AuthSources[name] = authSource - } - } - // Check for conflicts and merge authServices for name, authService := range file.AuthServices { if _, exists := merged.AuthServices[name]; exists { @@ -1089,20 +1081,6 @@ func run(cmd *Command) error { cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets cmd.cfg.PromptConfigs = finalToolsFile.Prompts - authSourceConfigs := finalToolsFile.AuthSources - if authSourceConfigs != nil { - cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead") - - for k, v := range authSourceConfigs { - if _, exists := cmd.cfg.AuthServiceConfigs[k]; exists { - errMsg := fmt.Errorf("resource conflict detected: authSource '%s' has the same name as an existing authService. Please rename your authSource", k) - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - cmd.cfg.AuthServiceConfigs[k] = v - } - } - instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString) if err != nil { errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err) diff --git a/cmd/root_test.go b/cmd/root_test.go index ac8e0c4e4e..a1024c605a 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -956,7 +956,7 @@ func TestParseToolFile(t *testing.T) { wantToolsFile ToolsFile }{ { - description: "basic example", + description: "basic example tools file v1", in: ` sources: my-pg-instance: @@ -1019,7 +1019,121 @@ func TestParseToolFile(t *testing.T) { }, }, { - description: "with prompts example", + description: "basic example tools file v2", + in: ` + kind: sources + name: my-pg-instance + type: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass +--- + kind: authServices + name: my-google-auth + type: google + clientId: testing-id +--- + kind: embeddingModels + name: gemini-model + type: gemini + model: gemini-embedding-001 + apiKey: some-key + dimension: 768 +--- + kind: tools + name: example_tool + type: postgres-sql + source: my-pg-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description +--- + kind: toolsets + name: example_toolset + tools: + - example_tool +--- + kind: prompts + name: code_review + description: ask llm to analyze code quality + messages: + - content: "please review the following code for quality: {{.code}}" + arguments: + - name: code + description: the code to review + `, + wantToolsFile: ToolsFile{ + Sources: server.SourceConfigs{ + "my-pg-instance": cloudsqlpgsrc.Config{ + Name: "my-pg-instance", + Type: cloudsqlpgsrc.SourceType, + Project: "my-project", + Region: "my-region", + Instance: "my-instance", + IPType: "public", + Database: "my_db", + User: "my_user", + Password: "my_pass", + }, + }, + AuthServices: server.AuthServiceConfigs{ + "my-google-auth": google.Config{ + Name: "my-google-auth", + Type: google.AuthServiceType, + ClientID: "testing-id", + }, + }, + EmbeddingModels: server.EmbeddingModelConfigs{ + "gemini-model": gemini.Config{ + Name: "gemini-model", + Type: gemini.EmbeddingModelType, + Model: "gemini-embedding-001", + ApiKey: "some-key", + Dimension: 768, + }, + }, + Tools: server.ToolConfigs{ + "example_tool": postgressql.Config{ + Name: "example_tool", + Type: "postgres-sql", + Source: "my-pg-instance", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("country", "some description"), + }, + AuthRequired: []string{}, + }, + }, + Toolsets: server.ToolsetConfigs{ + "example_toolset": tools.ToolsetConfig{ + Name: "example_toolset", + ToolNames: []string{"example_tool"}, + }, + }, + Prompts: server.PromptConfigs{ + "code_review": custom.Config{ + Name: "code_review", + Description: "ask llm to analyze code quality", + Arguments: prompts.Arguments{ + {Parameter: parameters.NewStringParameter("code", "the code to review")}, + }, + Messages: []prompts.Message{ + {Role: "user", Content: "please review the following code for quality: {{.code}}"}, + }, + }, + }, + }, + }, + { + description: "only prompts", in: ` prompts: my-prompt: @@ -1250,7 +1364,7 @@ func TestParseToolFileWithAuth(t *testing.T) { Password: "my_pass", }, }, - AuthSources: server.AuthServiceConfigs{ + AuthServices: server.AuthServiceConfigs{ "my-google-service": google.Config{ Name: "my-google-service", Type: google.AuthServiceType, diff --git a/internal/server/config.go b/internal/server/config.go index 1cec76f80b..9d387448f3 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -14,8 +14,10 @@ package server import ( + "bytes" "context" "fmt" + "io" "regexp" "strings" @@ -130,312 +132,256 @@ func (s *StringLevel) Type() string { // SourceConfigs is a type used to allow unmarshal of the data source config map type SourceConfigs map[string]sources.SourceConfig -// validate interface -var _ yaml.InterfaceUnmarshalerContext = &SourceConfigs{} - -func (c *SourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error { - *c = make(SourceConfigs) - // Parse the 'kind' fields for each source - var raw map[string]util.DelayedUnmarshaler - if err := unmarshal(&raw); err != nil { - return err - } - - for name, u := range raw { - // Unmarshal to a general type that ensure it capture all fields - var v map[string]any - if err := u.Unmarshal(&v); err != nil { - return fmt.Errorf("unable to unmarshal %q: %w", name, err) - } - - kind, ok := v["kind"] - if !ok { - return fmt.Errorf("missing 'kind' field for source %q", name) - } - kindStr, ok := kind.(string) - if !ok { - return fmt.Errorf("invalid 'kind' field for source %q (must be a string)", name) - } - - yamlDecoder, err := util.NewStrictDecoder(v) - if err != nil { - return fmt.Errorf("error creating YAML decoder for source %q: %w", name, err) - } - - sourceConfig, err := sources.DecodeConfig(ctx, kindStr, name, yamlDecoder) - if err != nil { - return err - } - (*c)[name] = sourceConfig - } - return nil -} - // AuthServiceConfigs is a type used to allow unmarshal of the data authService config map type AuthServiceConfigs map[string]auth.AuthServiceConfig -// validate interface -var _ yaml.InterfaceUnmarshalerContext = &AuthServiceConfigs{} - -func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error { - *c = make(AuthServiceConfigs) - // Parse the 'kind' fields for each authService - var raw map[string]util.DelayedUnmarshaler - if err := unmarshal(&raw); err != nil { - return err - } - - for name, u := range raw { - var v map[string]any - if err := u.Unmarshal(&v); err != nil { - return fmt.Errorf("unable to unmarshal %q: %w", name, err) - } - - kind, ok := v["kind"] - if !ok { - return fmt.Errorf("missing 'kind' field for %q", name) - } - - dec, err := util.NewStrictDecoder(v) - if err != nil { - return fmt.Errorf("error creating decoder: %w", err) - } - switch kind { - case google.AuthServiceType: - actual := google.Config{Name: name} - if err := dec.DecodeContext(ctx, &actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", kind, err) - } - (*c)[name] = actual - default: - return fmt.Errorf("%q is not a valid kind of auth source", kind) - } - } - return nil -} - // EmbeddingModelConfigs is a type used to allow unmarshal of the embedding model config map type EmbeddingModelConfigs map[string]embeddingmodels.EmbeddingModelConfig -// validate interface -var _ yaml.InterfaceUnmarshalerContext = &EmbeddingModelConfigs{} - -func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error { - *c = make(EmbeddingModelConfigs) - // Parse the 'kind' fields for each embedding model - var raw map[string]util.DelayedUnmarshaler - if err := unmarshal(&raw); err != nil { - return err - } - - for name, u := range raw { - // Unmarshal to a general type that ensure it capture all fields - var v map[string]any - if err := u.Unmarshal(&v); err != nil { - return fmt.Errorf("unable to unmarshal embedding model %q: %w", name, err) - } - - kind, ok := v["kind"] - if !ok { - return fmt.Errorf("missing 'kind' field for embedding model %q", name) - } - - dec, err := util.NewStrictDecoder(v) - if err != nil { - return fmt.Errorf("error creating decoder: %w", err) - } - switch kind { - case gemini.EmbeddingModelType: - actual := gemini.Config{Name: name} - if err := dec.DecodeContext(ctx, &actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", kind, err) - } - (*c)[name] = actual - default: - return fmt.Errorf("%q is not a valid kind of auth source", kind) - } - } - return nil -} - // ToolConfigs is a type used to allow unmarshal of the tool configs type ToolConfigs map[string]tools.ToolConfig -// validate interface -var _ yaml.InterfaceUnmarshalerContext = &ToolConfigs{} - -func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error { - *c = make(ToolConfigs) - // Parse the 'kind' fields for each source - var raw map[string]util.DelayedUnmarshaler - if err := unmarshal(&raw); err != nil { - return err - } - - for name, u := range raw { - err := NameValidation(name) - if err != nil { - return err - } - var v map[string]any - if err := u.Unmarshal(&v); err != nil { - return fmt.Errorf("unable to unmarshal %q: %w", name, err) - } - - // `authRequired` and `useClientOAuth` cannot be specified together - if v["authRequired"] != nil && v["useClientOAuth"] == true { - return fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method") - } - - // Make `authRequired` an empty list instead of nil for Tool manifest - if v["authRequired"] == nil { - v["authRequired"] = []string{} - } - - kindVal, ok := v["kind"] - if !ok { - return fmt.Errorf("missing 'kind' field for tool %q", name) - } - kindStr, ok := kindVal.(string) - if !ok { - return fmt.Errorf("invalid 'kind' field for tool %q (must be a string)", name) - } - - // validify parameter references - if rawParams, ok := v["parameters"]; ok { - if paramsList, ok := rawParams.([]any); ok { - // Turn params into a map - validParamNames := make(map[string]bool) - for _, rawP := range paramsList { - if pMap, ok := rawP.(map[string]any); ok { - if pName, ok := pMap["name"].(string); ok && pName != "" { - validParamNames[pName] = true - } - } - } - - // Validate references - for i, rawP := range paramsList { - pMap, ok := rawP.(map[string]any) - if !ok { - continue - } - - pName, _ := pMap["name"].(string) - refName, _ := pMap["valueFromParam"].(string) - - if refName != "" { - // Check if the referenced parameter exists - if !validParamNames[refName] { - return fmt.Errorf("tool %q config error: parameter %q (index %d) references '%q' in the 'valueFromParam' field, which is not a defined parameter", name, pName, i, refName) - } - - // Check for self-reference - if refName == pName { - return fmt.Errorf("tool %q config error: parameter %q cannot copy value from itself", name, pName) - } - } - } - } - } - yamlDecoder, err := util.NewStrictDecoder(v) - if err != nil { - return fmt.Errorf("error creating YAML decoder for tool %q: %w", name, err) - } - - toolCfg, err := tools.DecodeConfig(ctx, kindStr, name, yamlDecoder) - if err != nil { - return err - } - (*c)[name] = toolCfg - } - return nil -} - // ToolsetConfigs is a type used to allow unmarshal of the toolset configs type ToolsetConfigs map[string]tools.ToolsetConfig -// validate interface -var _ yaml.InterfaceUnmarshalerContext = &ToolsetConfigs{} - -func (c *ToolsetConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error { - *c = make(ToolsetConfigs) - - var raw map[string][]string - if err := unmarshal(&raw); err != nil { - return err - } - - for name, toolList := range raw { - (*c)[name] = tools.ToolsetConfig{Name: name, ToolNames: toolList} - } - return nil -} - // PromptConfigs is a type used to allow unmarshal of the prompt configs type PromptConfigs map[string]prompts.PromptConfig -// validate interface -var _ yaml.InterfaceUnmarshalerContext = &PromptConfigs{} - -func (c *PromptConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error { - *c = make(PromptConfigs) - var raw map[string]util.DelayedUnmarshaler - if err := unmarshal(&raw); err != nil { - return err - } - - for name, u := range raw { - var v map[string]any - if err := u.Unmarshal(&v); err != nil { - return fmt.Errorf("unable to unmarshal prompt %q: %w", name, err) - } - - // Look for the 'kind' field. If it's not present, kindStr will be an - // empty string, which prompts.DecodeConfig will correctly default to "custom". - var kindStr string - if kindVal, ok := v["kind"]; ok { - var isString bool - kindStr, isString = kindVal.(string) - if !isString { - return fmt.Errorf("invalid 'kind' field for prompt %q (must be a string)", name) - } - } - - // Create a new, strict decoder for this specific prompt's data. - yamlDecoder, err := util.NewStrictDecoder(v) - if err != nil { - return fmt.Errorf("error creating YAML decoder for prompt %q: %w", name, err) - } - - // Use the central registry to decode the prompt based on its kind. - promptCfg, err := prompts.DecodeConfig(ctx, kindStr, name, yamlDecoder) - if err != nil { - return err - } - (*c)[name] = promptCfg - } - return nil -} - -// PromptsetConfigs is a type used to allow unmarshal of the PromptsetConfigs configs +// PromptConfigs is a type used to allow unmarshal of the prompt configs type PromptsetConfigs map[string]prompts.PromptsetConfig -// validate interface -var _ yaml.InterfaceUnmarshalerContext = &PromptsetConfigs{} +func UnmarshalResourceConfig(ctx context.Context, raw []byte) (SourceConfigs, AuthServiceConfigs, EmbeddingModelConfigs, ToolConfigs, ToolsetConfigs, PromptConfigs, error) { + // prepare configs map + sourceConfigs := make(map[string]sources.SourceConfig) + authServiceConfigs := make(AuthServiceConfigs) + embeddingModelConfigs := make(EmbeddingModelConfigs) + toolConfigs := make(ToolConfigs) + toolsetConfigs := make(ToolsetConfigs) + promptConfigs := make(PromptConfigs) + // promptset configs is not yet supported -func (c *PromptsetConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error { - *c = make(PromptsetConfigs) + decoder := yaml.NewDecoder(bytes.NewReader(raw)) + // for loop to unmarshal documents with the `---` separator + for { + var resource map[string]any + if err := decoder.DecodeContext(ctx, &resource); err != nil { + if err == io.EOF { + break + } + return nil, nil, nil, nil, nil, nil, fmt.Errorf("unable to decode YAML document: %w", err) + } + var kind, name string + var ok bool + if kind, ok = resource["kind"].(string); !ok { + return nil, nil, nil, nil, nil, nil, fmt.Errorf("missing 'kind' field or it is not a string: %v", resource) + } + if name, ok = resource["name"].(string); !ok { + return nil, nil, nil, nil, nil, nil, fmt.Errorf("missing 'name' field or it is not a string") + } + // remove 'kind' from map for strict unmarshaling + delete(resource, "kind") + switch kind { + case "sources": + c, err := UnmarshalYAMLSourceConfig(ctx, name, resource) + if err != nil { + return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err) + } + sourceConfigs[name] = c + case "authServices": + c, err := UnmarshalYAMLAuthServiceConfig(ctx, name, resource) + if err != nil { + return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err) + } + authServiceConfigs[name] = c + case "tools": + c, err := UnmarshalYAMLToolConfig(ctx, name, resource) + if err != nil { + return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err) + } + toolConfigs[name] = c + case "toolsets": + c, err := UnmarshalYAMLToolsetConfig(ctx, name, resource) + if err != nil { + return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err) + } + toolsetConfigs[name] = c + case "embeddingModels": + c, err := UnmarshalYAMLEmbeddingModelConfig(ctx, name, resource) + if err != nil { + return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err) + } + embeddingModelConfigs[name] = c + case "prompts": + c, err := UnmarshalYAMLPromptConfig(ctx, name, resource) + if err != nil { + return nil, nil, nil, nil, nil, nil, fmt.Errorf("error unmarshaling %s: %s", kind, err) + } + promptConfigs[name] = c + default: + return nil, nil, nil, nil, nil, nil, fmt.Errorf("invalid kind %s", kind) + } + } + return sourceConfigs, authServiceConfigs, embeddingModelConfigs, toolConfigs, toolsetConfigs, promptConfigs, nil +} + +func UnmarshalYAMLSourceConfig(ctx context.Context, name string, r map[string]any) (sources.SourceConfig, error) { + typeStr, ok := r["type"].(string) + if !ok { + return nil, fmt.Errorf("missing 'type' field or it is not a string") + } + dec, err := util.NewStrictDecoder(r) + if err != nil { + return nil, fmt.Errorf("error creating decoder: %w", err) + } + sourceConfig, err := sources.DecodeConfig(ctx, typeStr, name, dec) + if err != nil { + return nil, err + } + return sourceConfig, nil +} + +func UnmarshalYAMLAuthServiceConfig(ctx context.Context, name string, r map[string]any) (auth.AuthServiceConfig, error) { + typeStr, ok := r["type"].(string) + if !ok { + return nil, fmt.Errorf("missing 'type' field or it is not a string") + } + if typeStr != google.AuthServiceType { + return nil, fmt.Errorf("%s is not a valid type of auth service", typeStr) + } + dec, err := util.NewStrictDecoder(r) + if err != nil { + return nil, fmt.Errorf("error creating decoder: %s", err) + } + actual := google.Config{Name: name} + if err := dec.DecodeContext(ctx, &actual); err != nil { + return nil, fmt.Errorf("unable to parse as %s: %w", name, err) + } + return actual, nil +} + +func UnmarshalYAMLEmbeddingModelConfig(ctx context.Context, name string, r map[string]any) (embeddingmodels.EmbeddingModelConfig, error) { + typeStr, ok := r["type"].(string) + if !ok { + return nil, fmt.Errorf("missing 'type' field or it is not a string") + } + if typeStr != gemini.EmbeddingModelType { + return nil, fmt.Errorf("%s is not a valid type of embedding model", typeStr) + } + dec, err := util.NewStrictDecoder(r) + if err != nil { + return nil, fmt.Errorf("error creating decoder: %s", err) + } + actual := gemini.Config{Name: name} + if err := dec.DecodeContext(ctx, &actual); err != nil { + return nil, fmt.Errorf("unable to parse as %q: %w", name, err) + } + return actual, nil +} + +func UnmarshalYAMLToolConfig(ctx context.Context, name string, r map[string]any) (tools.ToolConfig, error) { + typeStr, ok := r["type"].(string) + if !ok { + return nil, fmt.Errorf("missing 'type' field or it is not a string") + } + // `authRequired` and `useClientOAuth` cannot be specified together + if r["authRequired"] != nil && r["useClientOAuth"] == true { + return nil, fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method") + } + // Make `authRequired` an empty list instead of nil for Tool manifest + if r["authRequired"] == nil { + r["authRequired"] = []string{} + } + + // validify parameter references + if rawParams, ok := r["parameters"]; ok { + if paramsList, ok := rawParams.([]any); ok { + // Turn params into a map + validParamNames := make(map[string]bool) + for _, rawP := range paramsList { + if pMap, ok := rawP.(map[string]any); ok { + if pName, ok := pMap["name"].(string); ok && pName != "" { + validParamNames[pName] = true + } + } + } + + // Validate references + for i, rawP := range paramsList { + pMap, ok := rawP.(map[string]any) + if !ok { + continue + } + + pName, _ := pMap["name"].(string) + refName, _ := pMap["valueFromParam"].(string) + + if refName != "" { + // Check if the referenced parameter exists + if !validParamNames[refName] { + return nil, fmt.Errorf("tool %q config error: parameter %q (index %d) references '%q' in the 'valueFromParam' field, which is not a defined parameter", name, pName, i, refName) + } + + // Check for self-reference + if refName == pName { + return nil, fmt.Errorf("tool %q config error: parameter %q cannot copy value from itself", name, pName) + } + } + } + } + } + + dec, err := util.NewStrictDecoder(r) + if err != nil { + return nil, fmt.Errorf("error creating decoder: %s", err) + } + toolCfg, err := tools.DecodeConfig(ctx, typeStr, name, dec) + if err != nil { + return nil, err + } + return toolCfg, nil +} + +func UnmarshalYAMLToolsetConfig(ctx context.Context, name string, r map[string]any) (tools.ToolsetConfig, error) { + var toolsetConfig tools.ToolsetConfig + toolList, ok := r["tools"].([]string) + if !ok { + return toolsetConfig, fmt.Errorf("tools is missing or not a list of strings: %v", r) + } + justTools := map[string]any{"tools": toolList} + dec, err := util.NewStrictDecoder(justTools) + if err != nil { + return toolsetConfig, fmt.Errorf("error creating decoder: %s", err) + } var raw map[string][]string - if err := unmarshal(&raw); err != nil { - return err + if err := dec.DecodeContext(ctx, &raw); err != nil { + return toolsetConfig, fmt.Errorf("unable to unmarshal tools: %s", err) + } + return tools.ToolsetConfig{Name: name, ToolNames: raw["tools"]}, nil +} + +func UnmarshalYAMLPromptConfig(ctx context.Context, name string, r map[string]any) (prompts.PromptConfig, error) { + // Look for the 'kind' field. If it's not present, kindStr will be an + // empty string, which prompts.DecodeConfig will correctly default to "custom". + var typeStr string + if typeVal, ok := r["type"]; ok { + var isString bool + typeStr, isString = typeVal.(string) + if !isString { + return nil, fmt.Errorf("invalid 'type' field for prompt %q (must be a string)", name) + } + } + dec, err := util.NewStrictDecoder(r) + if err != nil { + return nil, fmt.Errorf("error creating decoder: %s", err) } - for name, promptList := range raw { - (*c)[name] = prompts.PromptsetConfig{Name: name, PromptNames: promptList} + // Use the central registry to decode the prompt based on its kind. + promptCfg, err := prompts.DecodeConfig(ctx, typeStr, name, dec) + if err != nil { + return nil, err } - return nil + return promptCfg, nil } // Tools naming validation is added in the MCP v2025-11-25, but we'll be