mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-04 04:05:22 -05:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d903aafd31 | ||
|
|
3f1908a822 | ||
|
|
eef7a94977 |
@@ -35,6 +35,7 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/cli/invoke"
|
||||
"github.com/googleapis/genai-toolbox/internal/cli/skills"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||
@@ -401,6 +402,8 @@ func NewCommand(opts ...Option) *Command {
|
||||
|
||||
// Register subcommands for tool invocation
|
||||
baseCmd.AddCommand(invoke.NewCommand(cmd))
|
||||
// Register subcommands for skill generation
|
||||
baseCmd.AddCommand(skills.NewCommand(cmd))
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
179
cmd/skill_generate_test.go
Normal file
179
cmd/skill_generate_test.go
Normal file
@@ -0,0 +1,179 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGenerateSkill(t *testing.T) {
|
||||
// Create a temporary directory for tests
|
||||
tmpDir := t.TempDir()
|
||||
outputDir := filepath.Join(tmpDir, "skills")
|
||||
|
||||
// Create a tools.yaml file with a sqlite tool
|
||||
toolsFileContent := `
|
||||
sources:
|
||||
my-sqlite:
|
||||
kind: sqlite
|
||||
database: test.db
|
||||
tools:
|
||||
hello-sqlite:
|
||||
kind: sqlite-sql
|
||||
source: my-sqlite
|
||||
description: "hello tool"
|
||||
statement: "SELECT 'hello' as greeting"
|
||||
`
|
||||
|
||||
toolsFilePath := filepath.Join(tmpDir, "tools.yaml")
|
||||
if err := os.WriteFile(toolsFilePath, []byte(toolsFileContent), 0644); err != nil {
|
||||
t.Fatalf("failed to write tools file: %v", err)
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"skills-generate",
|
||||
"--tools-file", toolsFilePath,
|
||||
"--output-dir", outputDir,
|
||||
"--name", "hello-sqlite",
|
||||
"--description", "hello tool",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, got, err := invokeCommandWithContext(ctx, args)
|
||||
if err != nil {
|
||||
t.Fatalf("command failed: %v\nOutput: %s", err, got)
|
||||
}
|
||||
|
||||
// Verify generated directory structure
|
||||
skillPath := filepath.Join(outputDir, "hello-sqlite")
|
||||
if _, err := os.Stat(skillPath); os.IsNotExist(err) {
|
||||
t.Fatalf("skill directory not created: %s", skillPath)
|
||||
}
|
||||
|
||||
// Check SKILL.md
|
||||
skillMarkdown := filepath.Join(skillPath, "SKILL.md")
|
||||
content, err := os.ReadFile(skillMarkdown)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read SKILL.md: %v", err)
|
||||
}
|
||||
|
||||
expectedFrontmatter := `---
|
||||
name: hello-sqlite
|
||||
description: hello tool
|
||||
---`
|
||||
if !strings.HasPrefix(string(content), expectedFrontmatter) {
|
||||
t.Errorf("SKILL.md does not have expected frontmatter format.\nExpected prefix:\n%s\nGot:\n%s", expectedFrontmatter, string(content))
|
||||
}
|
||||
|
||||
if !strings.Contains(string(content), "## Usage") {
|
||||
t.Errorf("SKILL.md does not contain '## Usage' section")
|
||||
}
|
||||
|
||||
if !strings.Contains(string(content), "## Scripts") {
|
||||
t.Errorf("SKILL.md does not contain '## Scripts' section")
|
||||
}
|
||||
|
||||
if !strings.Contains(string(content), "### hello-sqlite") {
|
||||
t.Errorf("SKILL.md does not contain '### hello-sqlite' tool header")
|
||||
}
|
||||
|
||||
// Check script file
|
||||
scriptFilename := "hello-sqlite.js"
|
||||
scriptPath := filepath.Join(skillPath, "scripts", scriptFilename)
|
||||
if _, err := os.Stat(scriptPath); os.IsNotExist(err) {
|
||||
t.Fatalf("script file not created: %s", scriptPath)
|
||||
}
|
||||
|
||||
scriptContent, err := os.ReadFile(scriptPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read script file: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(scriptContent), "hello-sqlite") {
|
||||
t.Errorf("script file does not contain expected tool name")
|
||||
}
|
||||
|
||||
// Check assets
|
||||
assetPath := filepath.Join(skillPath, "assets", "hello-sqlite.yaml")
|
||||
if _, err := os.Stat(assetPath); os.IsNotExist(err) {
|
||||
t.Fatalf("asset file not created: %s", assetPath)
|
||||
}
|
||||
assetContent, err := os.ReadFile(assetPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read asset file: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(assetContent), "hello-sqlite") {
|
||||
t.Errorf("asset file does not contain expected tool name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSkill_NoConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
outputDir := filepath.Join(tmpDir, "skills")
|
||||
|
||||
args := []string{
|
||||
"skills-generate",
|
||||
"--output-dir", outputDir,
|
||||
"--name", "test",
|
||||
"--description", "test",
|
||||
}
|
||||
|
||||
_, _, err := invokeCommandWithContext(context.Background(), args)
|
||||
if err == nil {
|
||||
t.Fatal("expected command to fail when no configuration is provided and tools.yaml is missing")
|
||||
}
|
||||
|
||||
// Should not have created the directory if no config was processed
|
||||
if _, err := os.Stat(outputDir); !os.IsNotExist(err) {
|
||||
t.Errorf("output directory should not have been created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSkill_MissingArguments(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
toolsFilePath := filepath.Join(tmpDir, "tools.yaml")
|
||||
if err := os.WriteFile(toolsFilePath, []byte("tools: {}"), 0644); err != nil {
|
||||
t.Fatalf("failed to write tools file: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
}{
|
||||
{
|
||||
name: "missing name",
|
||||
args: []string{"skills-generate", "--tools-file", toolsFilePath, "--description", "test"},
|
||||
},
|
||||
{
|
||||
name: "missing description",
|
||||
args: []string{"skills-generate", "--tools-file", toolsFilePath, "--name", "test"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, got, err := invokeCommandWithContext(context.Background(), tt.args)
|
||||
if err == nil {
|
||||
t.Fatalf("expected command to fail due to missing arguments, but it succeeded\nOutput: %s", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -53,7 +53,7 @@ export async function main() {
|
||||
|
||||
for (const query of queries) {
|
||||
conversationHistory.push({ role: "user", content: [{ text: query }] });
|
||||
const response = await ai.generate({
|
||||
let response = await ai.generate({
|
||||
messages: conversationHistory,
|
||||
tools: tools,
|
||||
});
|
||||
|
||||
112
docs/en/how-to/generate_skill.md
Normal file
112
docs/en/how-to/generate_skill.md
Normal file
@@ -0,0 +1,112 @@
|
||||
---
|
||||
title: "Generate Agent Skills"
|
||||
type: docs
|
||||
weight: 10
|
||||
description: >
|
||||
How to generate agent skills from a toolset.
|
||||
---
|
||||
|
||||
The `skills-generate` command allows you to convert a **toolset** into an **Agent Skill**. A toolset is a collection of tools, and the generated skill will contain metadata and execution scripts for all tools within that toolset, complying with the [Agent Skill specification](https://agentskills.io/specification).
|
||||
|
||||
## Before you begin
|
||||
|
||||
1. Make sure you have the `toolbox` executable in your PATH.
|
||||
2. Make sure you have [Node.js](https://nodejs.org/) installed on your system.
|
||||
|
||||
## Generating a Skill from a Toolset
|
||||
|
||||
A skill package consists of a `SKILL.md` file (with required YAML frontmatter) and a set of Node.js scripts. Each tool defined in your toolset maps to a corresponding script in the generated Node.js scripts (`.js`) that work across different platforms (Linux, macOS, Windows).
|
||||
|
||||
|
||||
### Command Usage
|
||||
|
||||
The basic syntax for the command is:
|
||||
|
||||
```bash
|
||||
toolbox <tool-source> skills-generate \
|
||||
--name <skill-name> \
|
||||
--toolset <toolset-name> \
|
||||
--description <description> \
|
||||
--output-dir <output-directory>
|
||||
```
|
||||
|
||||
- `<tool-source>`: Can be `--tools-file`, `--tools-files`, `--tools-folder`, and `--prebuilt`. See the [CLI Reference](../reference/cli.md) for details.
|
||||
- `--name`: Name of the generated skill.
|
||||
- `--description`: Description of the generated skill.
|
||||
- `--toolset`: (Optional) Name of the toolset to convert into a skill. If not provided, all tools will be included.
|
||||
- `--output-dir`: (Optional) Directory to output generated skills (default: "skills").
|
||||
|
||||
{{< notice note >}}
|
||||
**Note:** The `<skill-name>` must follow the Agent Skill [naming convention](https://agentskills.io/specification): it must contain only lowercase alphanumeric characters and hyphens, cannot start or end with a hyphen, and cannot contain consecutive hyphens (e.g., `my-skill`, `data-processing`).
|
||||
{{< /notice >}}
|
||||
|
||||
### Example: Custom Tools File
|
||||
|
||||
1. Create a `tools.yaml` file with a toolset and some tools:
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
tool_a:
|
||||
description: "First tool"
|
||||
run:
|
||||
command: "echo 'Tool A'"
|
||||
tool_b:
|
||||
description: "Second tool"
|
||||
run:
|
||||
command: "echo 'Tool B'"
|
||||
toolsets:
|
||||
my_toolset:
|
||||
tools:
|
||||
- tool_a
|
||||
- tool_b
|
||||
```
|
||||
|
||||
2. Generate the skill:
|
||||
|
||||
```bash
|
||||
toolbox --tools-file tools.yaml skills-generate \
|
||||
--name "my-skill" \
|
||||
--toolset "my_toolset" \
|
||||
--description "A skill containing multiple tools" \
|
||||
--output-dir "generated-skills"
|
||||
```
|
||||
|
||||
3. The generated skill directory structure:
|
||||
|
||||
```text
|
||||
generated-skills/
|
||||
└── my-skill/
|
||||
├── SKILL.md
|
||||
├── assets/
|
||||
│ ├── tool_a.yaml
|
||||
│ └── tool_b.yaml
|
||||
└── scripts/
|
||||
├── tool_a.js
|
||||
└── tool_b.js
|
||||
```
|
||||
|
||||
In this example, the skill contains two Node.js scripts (`tool_a.js` and `tool_b.js`), each mapping to a tool in the original toolset.
|
||||
|
||||
### Example: Prebuilt Configuration
|
||||
|
||||
You can also generate skills from prebuilt toolsets:
|
||||
|
||||
```bash
|
||||
toolbox --prebuilt alloydb-postgres-admin skills-generate \
|
||||
--name "alloydb-postgres-admin" \
|
||||
--description "skill for performing administrative operations on alloydb"
|
||||
```
|
||||
|
||||
## Installing the Generated Skill in Gemini CLI
|
||||
|
||||
Once you have generated a skill, you can install it into the Gemini CLI using the `gemini skills install` command.
|
||||
|
||||
### Installation Command
|
||||
|
||||
Provide the path to the directory containing the generated skill:
|
||||
|
||||
```bash
|
||||
gemini skills install /path/to/generated-skills/my-skill
|
||||
```
|
||||
|
||||
Alternatively, use ~/.gemini/skills as the `--output-dir` to generate the skill straight to the Gemini CLI.
|
||||
@@ -13,21 +13,22 @@ The `invoke` command allows you to invoke tools defined in your configuration di
|
||||
|
||||
{{< notice tip >}}
|
||||
**Keep configurations minimal:** The `invoke` command initializes *all* resources (sources, tools, etc.) defined in your configuration files during execution. To ensure fast response times, consider using a minimal configuration file containing only the tools you need for the specific invocation.
|
||||
{{< notice tip >}}
|
||||
{{< /notice >}}
|
||||
|
||||
## Prerequisites
|
||||
## Before you begin
|
||||
|
||||
- You have the `toolbox` binary installed or built.
|
||||
- You have a valid tool configuration file (e.g., `tools.yaml`).
|
||||
1. Make sure you have the `toolbox` binary installed or built.
|
||||
2. Make sure you have a valid tool configuration file (e.g., `tools.yaml`).
|
||||
|
||||
## Basic Usage
|
||||
### Command Usage
|
||||
|
||||
The basic syntax for the command is:
|
||||
|
||||
```bash
|
||||
toolbox [--tools-file <path> | --prebuilt <name>] invoke <tool-name> [params]
|
||||
toolbox <tool-source> invoke <tool-name> [params]
|
||||
```
|
||||
|
||||
- `<tool-source>`: Can be `--tools-file`, `--tools-files`, `--tools-folder`, and `--prebuilt`. See the [CLI Reference](../reference/cli.md) for details.
|
||||
- `<tool-name>`: The name of the tool you want to call. This must match the name defined in your `tools.yaml`.
|
||||
- `[params]`: (Optional) A JSON string representing the arguments for the tool.
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ description: >
|
||||
|
||||
## Sub Commands
|
||||
|
||||
### `invoke`
|
||||
<details>
|
||||
<summary><code>invoke</code></summary>
|
||||
|
||||
Executes a tool directly with the provided parameters. This is useful for testing tool configurations and parameters without needing a full client setup.
|
||||
|
||||
@@ -42,8 +43,36 @@ Executes a tool directly with the provided parameters. This is useful for testin
|
||||
toolbox invoke <tool-name> [params]
|
||||
```
|
||||
|
||||
- `<tool-name>`: The name of the tool to execute (as defined in your configuration).
|
||||
- `[params]`: (Optional) A JSON string containing the parameters for the tool.
|
||||
**Arguments:**
|
||||
|
||||
- `tool-name`: The name of the tool to execute (as defined in your configuration).
|
||||
- `params`: (Optional) A JSON string containing the parameters for the tool.
|
||||
|
||||
For more detailed instructions, see [Invoke Tools via CLI](../how-to/invoke_tool.md).
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><code>skills-generate</code></summary>
|
||||
|
||||
Generates a skill package from a specified toolset. Each tool in the toolset will have a corresponding Node.js execution script in the generated skill.
|
||||
|
||||
**Syntax:**
|
||||
|
||||
```bash
|
||||
toolbox skills-generate --name <name> --description <description> --toolset <toolset> --output-dir <output>
|
||||
```
|
||||
|
||||
**Flags:**
|
||||
|
||||
- `--name`: Name of the generated skill.
|
||||
- `--description`: Description of the generated skill.
|
||||
- `--toolset`: (Optional) Name of the toolset to convert into a skill. If not provided, all tools will be included.
|
||||
- `--output-dir`: (Optional) Directory to output generated skills (default: "skills").
|
||||
|
||||
For more detailed instructions, see [Generate Agent Skills](../how-to/generate_skill.md).
|
||||
|
||||
</details>
|
||||
|
||||
## Examples
|
||||
|
||||
|
||||
237
internal/cli/skills/command.go
Normal file
237
internal/cli/skills/command.go
Normal file
@@ -0,0 +1,237 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// RootCommand defines the interface for required by skills-generate subcommand.
|
||||
// This allows subcommands to access shared resources and functionality without
|
||||
// direct coupling to the root command's implementation.
|
||||
type RootCommand interface {
|
||||
// Config returns a copy of the current server configuration.
|
||||
Config() server.ServerConfig
|
||||
|
||||
// LoadConfig loads and merges the configuration from files, folders, and prebuilts.
|
||||
LoadConfig(ctx context.Context) error
|
||||
|
||||
// Setup initializes the runtime environment, including logging and telemetry.
|
||||
// It returns the updated context and a shutdown function to be called when finished.
|
||||
Setup(ctx context.Context) (context.Context, func(context.Context) error, error)
|
||||
|
||||
// Logger returns the logger instance.
|
||||
Logger() log.Logger
|
||||
}
|
||||
|
||||
// Command is the command for generating skills.
|
||||
type Command struct {
|
||||
*cobra.Command
|
||||
rootCmd RootCommand
|
||||
name string
|
||||
description string
|
||||
toolset string
|
||||
outputDir string
|
||||
}
|
||||
|
||||
// NewCommand creates a new Command.
|
||||
func NewCommand(rootCmd RootCommand) *cobra.Command {
|
||||
cmd := &Command{
|
||||
rootCmd: rootCmd,
|
||||
}
|
||||
cmd.Command = &cobra.Command{
|
||||
Use: "skills-generate",
|
||||
Short: "Generate skills from tool configurations",
|
||||
RunE: func(c *cobra.Command, args []string) error {
|
||||
return cmd.run(c)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&cmd.name, "name", "", "Name of the generated skill.")
|
||||
cmd.Flags().StringVar(&cmd.description, "description", "", "Description of the generated skill")
|
||||
cmd.Flags().StringVar(&cmd.toolset, "toolset", "", "Name of the toolset to convert into a skill. If not provided, all tools will be included.")
|
||||
cmd.Flags().StringVar(&cmd.outputDir, "output-dir", "skills", "Directory to output generated skills")
|
||||
|
||||
_ = cmd.MarkFlagRequired("name")
|
||||
_ = cmd.MarkFlagRequired("description")
|
||||
return cmd.Command
|
||||
}
|
||||
|
||||
func (c *Command) run(cmd *cobra.Command) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
ctx, shutdown, err := c.rootCmd.Setup(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = shutdown(ctx)
|
||||
}()
|
||||
|
||||
logger := c.rootCmd.Logger()
|
||||
|
||||
// Load and merge tool configurations
|
||||
if err := c.rootCmd.LoadConfig(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(c.outputDir, 0755); err != nil {
|
||||
errMsg := fmt.Errorf("error creating output directory: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
logger.InfoContext(ctx, fmt.Sprintf("Generating skill '%s'...", c.name))
|
||||
|
||||
// Initialize toolbox and collect tools
|
||||
allTools, err := c.collectTools(ctx)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error collecting tools: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
if len(allTools) == 0 {
|
||||
logger.InfoContext(ctx, "No tools found to generate.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate the combined skill directory
|
||||
skillPath := filepath.Join(c.outputDir, c.name)
|
||||
if err := os.MkdirAll(skillPath, 0755); err != nil {
|
||||
errMsg := fmt.Errorf("error creating skill directory: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
// Generate assets directory
|
||||
assetsPath := filepath.Join(skillPath, "assets")
|
||||
if err := os.MkdirAll(assetsPath, 0755); err != nil {
|
||||
errMsg := fmt.Errorf("error creating assets dir: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
// Generate scripts directory
|
||||
scriptsPath := filepath.Join(skillPath, "scripts")
|
||||
if err := os.MkdirAll(scriptsPath, 0755); err != nil {
|
||||
errMsg := fmt.Errorf("error creating scripts dir: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
// Iterate over keys to ensure deterministic order
|
||||
var toolNames []string
|
||||
for name := range allTools {
|
||||
toolNames = append(toolNames, name)
|
||||
}
|
||||
sort.Strings(toolNames)
|
||||
|
||||
for _, toolName := range toolNames {
|
||||
// Generate YAML config in asset directory
|
||||
minimizedContent, err := generateToolConfigYAML(c.rootCmd.Config(), toolName)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error generating filtered config for %s: %w", toolName, err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
specificToolsFileName := fmt.Sprintf("%s.yaml", toolName)
|
||||
if minimizedContent != nil {
|
||||
destPath := filepath.Join(assetsPath, specificToolsFileName)
|
||||
if err := os.WriteFile(destPath, minimizedContent, 0644); err != nil {
|
||||
errMsg := fmt.Errorf("error writing filtered config for %s: %w", toolName, err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
}
|
||||
|
||||
// Generate wrapper script in scripts directory
|
||||
scriptContent, err := generateScriptContent(toolName, specificToolsFileName)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error generating script content for %s: %w", toolName, err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
scriptFilename := filepath.Join(scriptsPath, fmt.Sprintf("%s.js", toolName))
|
||||
if err := os.WriteFile(scriptFilename, []byte(scriptContent), 0755); err != nil {
|
||||
errMsg := fmt.Errorf("error writing script %s: %w", scriptFilename, err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
}
|
||||
|
||||
// Generate SKILL.md
|
||||
skillContent, err := generateSkillMarkdown(c.name, c.description, allTools)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error generating SKILL.md content: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
skillMdPath := filepath.Join(skillPath, "SKILL.md")
|
||||
if err := os.WriteFile(skillMdPath, []byte(skillContent), 0644); err != nil {
|
||||
errMsg := fmt.Errorf("error writing SKILL.md: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
logger.InfoContext(ctx, fmt.Sprintf("Successfully generated skill '%s' with %d tools.", c.name, len(allTools)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Command) collectTools(ctx context.Context) (map[string]tools.Tool, error) {
|
||||
// Initialize Resources
|
||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, c.rootCmd.Config())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize resources: %w", err)
|
||||
}
|
||||
|
||||
resourceMgr := resources.NewResourceManager(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
||||
|
||||
result := make(map[string]tools.Tool)
|
||||
|
||||
if c.toolset == "" {
|
||||
return toolsMap, nil
|
||||
}
|
||||
|
||||
ts, ok := resourceMgr.GetToolset(c.toolset)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("toolset %q not found", c.toolset)
|
||||
}
|
||||
|
||||
for _, t := range ts.Tools {
|
||||
if t != nil {
|
||||
tool := *t
|
||||
result[tool.McpManifest().Name] = tool
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
297
internal/cli/skills/generator.go
Normal file
297
internal/cli/skills/generator.go
Normal file
@@ -0,0 +1,297 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package skills
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
const skillTemplate = `---
|
||||
name: {{.SkillName}}
|
||||
description: {{.SkillDescription}}
|
||||
---
|
||||
|
||||
## Usage
|
||||
|
||||
All scripts can be executed using Node.js. Replace ` + "`" + `<param_name>` + "`" + ` and ` + "`" + `<param_value>` + "`" + ` with actual values.
|
||||
|
||||
**Bash:**
|
||||
` + "`" + `node scripts/<script_name>.js '{"<param_name>": "<param_value>"}'` + "`" + `
|
||||
|
||||
**PowerShell:**
|
||||
` + "`" + `node scripts/<script_name>.js '{\"<param_name>\": \"<param_value>\"}'` + "`" + `
|
||||
|
||||
## Scripts
|
||||
|
||||
{{range .Tools}}
|
||||
### {{.Name}}
|
||||
|
||||
{{.Description}}
|
||||
|
||||
{{.ParametersSchema}}
|
||||
|
||||
---
|
||||
{{end}}
|
||||
`
|
||||
|
||||
type toolTemplateData struct {
|
||||
Name string
|
||||
Description string
|
||||
ParametersSchema string
|
||||
}
|
||||
|
||||
type skillTemplateData struct {
|
||||
SkillName string
|
||||
SkillDescription string
|
||||
Tools []toolTemplateData
|
||||
}
|
||||
|
||||
// generateSkillMarkdown generates the content of the SKILL.md file.
|
||||
// It includes usage instructions and a reference section for each tool in the skill,
|
||||
// detailing its description and parameters.
|
||||
func generateSkillMarkdown(skillName, skillDescription string, toolsMap map[string]tools.Tool) (string, error) {
|
||||
var toolsData []toolTemplateData
|
||||
|
||||
// Order tools based on name
|
||||
var toolNames []string
|
||||
for name := range toolsMap {
|
||||
toolNames = append(toolNames, name)
|
||||
}
|
||||
sort.Strings(toolNames)
|
||||
|
||||
for _, name := range toolNames {
|
||||
tool := toolsMap[name]
|
||||
manifest := tool.Manifest()
|
||||
|
||||
parametersSchema, err := formatParameters(manifest.Parameters)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
toolsData = append(toolsData, toolTemplateData{
|
||||
Name: name,
|
||||
Description: manifest.Description,
|
||||
ParametersSchema: parametersSchema,
|
||||
})
|
||||
}
|
||||
|
||||
data := skillTemplateData{
|
||||
SkillName: skillName,
|
||||
SkillDescription: skillDescription,
|
||||
Tools: toolsData,
|
||||
}
|
||||
|
||||
tmpl, err := template.New("markdown").Parse(skillTemplate)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing markdown template: %w", err)
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
if err := tmpl.Execute(&buf, data); err != nil {
|
||||
return "", fmt.Errorf("error executing markdown template: %w", err)
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
const nodeScriptTemplate = `#!/usr/bin/env node
|
||||
|
||||
const { spawn, execSync } = require('child_process');
|
||||
const path = require('path');
|
||||
const fs = require('fs');
|
||||
|
||||
const toolName = "{{.Name}}";
|
||||
const toolsFileName = "{{.ToolsFileName}}";
|
||||
|
||||
function getToolboxPath() {
|
||||
try {
|
||||
const checkCommand = process.platform === 'win32' ? 'where toolbox' : 'which toolbox';
|
||||
const globalPath = execSync(checkCommand, { stdio: 'pipe', encoding: 'utf-8' }).trim();
|
||||
if (globalPath) {
|
||||
return globalPath.split('\n')[0].trim();
|
||||
}
|
||||
} catch (e) {
|
||||
// Ignore error;
|
||||
}
|
||||
const localPath = path.resolve(__dirname, '../../../toolbox');
|
||||
if (fs.existsSync(localPath)) {
|
||||
return localPath;
|
||||
}
|
||||
throw new Error("Toolbox binary not found");
|
||||
}
|
||||
|
||||
let toolboxBinary;
|
||||
try {
|
||||
toolboxBinary = getToolboxPath();
|
||||
} catch (err) {
|
||||
console.error("Error:", err.message);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
let configArgs = [];
|
||||
if (toolsFileName) {
|
||||
configArgs.push("--tools-file", path.join(__dirname, "..", "assets", toolsFileName));
|
||||
}
|
||||
|
||||
const args = process.argv.slice(2);
|
||||
const toolboxArgs = [...configArgs, "invoke", toolName, ...args];
|
||||
|
||||
const child = spawn(toolboxBinary, toolboxArgs, { stdio: 'inherit' });
|
||||
|
||||
child.on('close', (code) => {
|
||||
process.exit(code);
|
||||
});
|
||||
|
||||
child.on('error', (err) => {
|
||||
console.error("Error executing toolbox:", err);
|
||||
process.exit(1);
|
||||
});
|
||||
`
|
||||
|
||||
type scriptData struct {
|
||||
Name string
|
||||
ToolsFileName string
|
||||
}
|
||||
|
||||
// generateScriptContent creates the content for a Node.js wrapper script.
|
||||
// This script invokes the toolbox CLI with the appropriate configuration
|
||||
// (using a generated tools file) and arguments to execute the specific tool.
|
||||
func generateScriptContent(name string, toolsFileName string) (string, error) {
|
||||
data := scriptData{
|
||||
Name: name,
|
||||
ToolsFileName: toolsFileName,
|
||||
}
|
||||
|
||||
tmpl, err := template.New("script").Parse(nodeScriptTemplate)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing script template: %w", err)
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
if err := tmpl.Execute(&buf, data); err != nil {
|
||||
return "", fmt.Errorf("error executing script template: %w", err)
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// formatParameters converts a list of parameter manifests into a formatted JSON schema string.
|
||||
// This schema is used in the skill documentation to describe the input parameters for a tool.
|
||||
func formatParameters(params []parameters.ParameterManifest) (string, error) {
|
||||
if len(params) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
properties := make(map[string]interface{})
|
||||
var required []string
|
||||
|
||||
for _, p := range params {
|
||||
paramMap := map[string]interface{}{
|
||||
"type": p.Type,
|
||||
"description": p.Description,
|
||||
}
|
||||
if p.Default != nil {
|
||||
paramMap["default"] = p.Default
|
||||
}
|
||||
properties[p.Name] = paramMap
|
||||
if p.Required {
|
||||
required = append(required, p.Name)
|
||||
}
|
||||
}
|
||||
|
||||
schema := map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
}
|
||||
if len(required) > 0 {
|
||||
schema["required"] = required
|
||||
}
|
||||
|
||||
schemaJSON, err := json.MarshalIndent(schema, "", " ")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error generating parameters schema: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("#### Parameters\n\n```json\n%s\n```", string(schemaJSON)), nil
|
||||
}
|
||||
|
||||
// generateToolConfigYAML generates the YAML configuration for a single tool and its dependency (source).
|
||||
// It extracts the relevant tool and source configurations from the server config and formats them
|
||||
// into a YAML document suitable for inclusion in the skill's assets.
|
||||
func generateToolConfigYAML(cfg server.ServerConfig, toolName string) ([]byte, error) {
|
||||
toolCfg, ok := cfg.ToolConfigs[toolName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error finding tool config: %s", toolName)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
encoder := yaml.NewEncoder(&buf)
|
||||
|
||||
// Process Tool Config
|
||||
toolWrapper := struct {
|
||||
Kind string `yaml:"kind"`
|
||||
Name string `yaml:"name"`
|
||||
Config tools.ToolConfig `yaml:",inline"`
|
||||
}{
|
||||
Kind: "tools",
|
||||
Name: toolName,
|
||||
Config: toolCfg,
|
||||
}
|
||||
|
||||
if err := encoder.Encode(toolWrapper); err != nil {
|
||||
return nil, fmt.Errorf("error encoding tool config: %w", err)
|
||||
}
|
||||
|
||||
// Process Source Config
|
||||
var toolMap map[string]interface{}
|
||||
b, err := yaml.Marshal(toolCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling tool config: %w", err)
|
||||
}
|
||||
if err := yaml.Unmarshal(b, &toolMap); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling tool config map: %w", err)
|
||||
}
|
||||
|
||||
if sourceName, ok := toolMap["source"].(string); ok {
|
||||
if sourceCfg, ok := cfg.SourceConfigs[sourceName]; ok {
|
||||
sourceWrapper := struct {
|
||||
Kind string `yaml:"kind"`
|
||||
Name string `yaml:"name"`
|
||||
Config sources.SourceConfig `yaml:",inline"`
|
||||
}{
|
||||
Kind: "sources",
|
||||
Name: sourceName,
|
||||
Config: sourceCfg,
|
||||
}
|
||||
|
||||
if err := encoder.Encode(sourceWrapper); err != nil {
|
||||
return nil, fmt.Errorf("error encoding source config: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
338
internal/cli/skills/generator_test.go
Normal file
338
internal/cli/skills/generator_test.go
Normal file
@@ -0,0 +1,338 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
type MockToolConfig struct {
|
||||
Type string `yaml:"type"`
|
||||
Source string `yaml:"source"`
|
||||
Other string `yaml:"other"`
|
||||
Parameters parameters.Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
func (m MockToolConfig) ToolConfigType() string {
|
||||
return m.Type
|
||||
}
|
||||
|
||||
func (m MockToolConfig) Initialize(map[string]sources.Source) (tools.Tool, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type MockSourceConfig struct {
|
||||
TypeVal string
|
||||
ConnVal string
|
||||
}
|
||||
|
||||
func (m MockSourceConfig) SourceConfigType() string {
|
||||
return m.TypeVal
|
||||
}
|
||||
|
||||
func (m MockSourceConfig) Initialize(context.Context, trace.Tracer) (sources.Source, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m MockSourceConfig) MarshalYAML() (interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"connection_string": m.ConnVal,
|
||||
"type": m.TypeVal,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestFormatParameters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
params []parameters.ParameterManifest
|
||||
wantContains []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty parameters",
|
||||
params: []parameters.ParameterManifest{},
|
||||
wantContains: []string{""},
|
||||
},
|
||||
{
|
||||
name: "single required string parameter",
|
||||
params: []parameters.ParameterManifest{
|
||||
{
|
||||
Name: "param1",
|
||||
Description: "A test parameter",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
wantContains: []string{
|
||||
"## Parameters",
|
||||
"```json",
|
||||
`"type": "object"`,
|
||||
`"properties": {`,
|
||||
`"param1": {`,
|
||||
`"type": "string"`,
|
||||
`"description": "A test parameter"`,
|
||||
`"required": [`,
|
||||
`"param1"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed parameters with defaults",
|
||||
params: []parameters.ParameterManifest{
|
||||
{
|
||||
Name: "param1",
|
||||
Description: "Param 1",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "param2",
|
||||
Description: "Param 2",
|
||||
Type: "integer",
|
||||
Default: 42,
|
||||
Required: false,
|
||||
},
|
||||
},
|
||||
wantContains: []string{
|
||||
`"param1": {`,
|
||||
`"param2": {`,
|
||||
`"default": 42`,
|
||||
`"required": [`,
|
||||
`"param1"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := formatParameters(tt.params)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("formatParameters() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
if len(tt.params) == 0 {
|
||||
if got != "" {
|
||||
t.Errorf("formatParameters() = %v, want empty string", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, want := range tt.wantContains {
|
||||
if !strings.Contains(got, want) {
|
||||
t.Errorf("formatParameters() result missing expected string: %s\nGot:\n%s", want, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSkillMarkdown(t *testing.T) {
|
||||
toolsMap := map[string]tools.Tool{
|
||||
"tool1": server.MockTool{
|
||||
Description: "First tool",
|
||||
Params: []parameters.Parameter{
|
||||
parameters.NewStringParameter("p1", "d1"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, err := generateSkillMarkdown("MySkill", "My Description", toolsMap)
|
||||
if err != nil {
|
||||
t.Fatalf("generateSkillMarkdown() error = %v", err)
|
||||
}
|
||||
|
||||
expectedSubstrings := []string{
|
||||
"name: MySkill",
|
||||
"description: My Description",
|
||||
"## Usage",
|
||||
"All scripts can be executed using Node.js",
|
||||
"**Bash:**",
|
||||
"`node scripts/<script_name>.js '{\"<param_name>\": \"<param_value>\"}'`",
|
||||
"**PowerShell:**",
|
||||
"`node scripts/<script_name>.js '{\"<param_name>\": \"<param_value>\"}'`",
|
||||
"## Scripts",
|
||||
"### tool1",
|
||||
"First tool",
|
||||
"## Parameters",
|
||||
}
|
||||
|
||||
for _, s := range expectedSubstrings {
|
||||
if !strings.Contains(got, s) {
|
||||
t.Errorf("generateSkillMarkdown() missing substring %q", s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateScriptContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolName string
|
||||
toolsFileName string
|
||||
wantContains []string
|
||||
}{
|
||||
{
|
||||
name: "basic script",
|
||||
toolName: "test-tool",
|
||||
toolsFileName: "",
|
||||
wantContains: []string{
|
||||
`const toolName = "test-tool";`,
|
||||
`const toolsFileName = "";`,
|
||||
`const toolboxArgs = [...configArgs, "invoke", toolName, ...args];`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "script with tools file",
|
||||
toolName: "complex-tool",
|
||||
toolsFileName: "tools.yaml",
|
||||
wantContains: []string{
|
||||
`const toolName = "complex-tool";`,
|
||||
`const toolsFileName = "tools.yaml";`,
|
||||
`configArgs.push("--tools-file", path.join(__dirname, "..", "assets", toolsFileName));`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := generateScriptContent(tt.toolName, tt.toolsFileName)
|
||||
if err != nil {
|
||||
t.Fatalf("generateScriptContent() error = %v", err)
|
||||
}
|
||||
|
||||
for _, s := range tt.wantContains {
|
||||
if !strings.Contains(got, s) {
|
||||
t.Errorf("generateScriptContent() missing substring %q\nGot:\n%s", s, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateToolConfigYAML(t *testing.T) {
|
||||
cfg := server.ServerConfig{
|
||||
ToolConfigs: server.ToolConfigs{
|
||||
"tool1": MockToolConfig{
|
||||
Type: "custom-tool",
|
||||
Source: "src1",
|
||||
Other: "foo",
|
||||
},
|
||||
"toolNoSource": MockToolConfig{
|
||||
Type: "http",
|
||||
},
|
||||
"toolWithParams": MockToolConfig{
|
||||
Type: "custom-tool",
|
||||
Parameters: []parameters.Parameter{
|
||||
parameters.NewStringParameter("param1", "desc1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
SourceConfigs: server.SourceConfigs{
|
||||
"src1": MockSourceConfig{
|
||||
TypeVal: "postgres",
|
||||
ConnVal: "conn1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
toolName string
|
||||
wantContains []string
|
||||
wantErr bool
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "tool with source",
|
||||
toolName: "tool1",
|
||||
wantContains: []string{
|
||||
"kind: tools",
|
||||
"name: tool1",
|
||||
"type: custom-tool",
|
||||
"source: src1",
|
||||
"other: foo",
|
||||
"---",
|
||||
"kind: sources",
|
||||
"name: src1",
|
||||
"type: postgres",
|
||||
"connection_string: conn1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool without source",
|
||||
toolName: "toolNoSource",
|
||||
wantContains: []string{
|
||||
"kind: tools",
|
||||
"name: toolNoSource",
|
||||
"type: http",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool with parameters",
|
||||
toolName: "toolWithParams",
|
||||
wantContains: []string{
|
||||
"kind: tools",
|
||||
"name: toolWithParams",
|
||||
"type: custom-tool",
|
||||
"parameters:",
|
||||
"- name: param1",
|
||||
"type: string",
|
||||
"description: desc1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-existent tool",
|
||||
toolName: "missing-tool",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotBytes, err := generateToolConfigYAML(cfg, tt.toolName)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("generateToolConfigYAML() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantNil {
|
||||
if gotBytes != nil {
|
||||
t.Errorf("generateToolConfigYAML() expected nil, got %s", string(gotBytes))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
got := string(gotBytes)
|
||||
for _, want := range tt.wantContains {
|
||||
if !strings.Contains(got, want) {
|
||||
t.Errorf("generateToolConfigYAML() result missing expected string: %q\nGot:\n%s", want, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
@@ -234,10 +235,8 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
||||
if err != nil {
|
||||
// If auth error, return 401
|
||||
errMsg := fmt.Sprintf("error parsing authenticated parameters from ID token: %w", err)
|
||||
var clientServerErr *util.ClientServerError
|
||||
if errors.As(err, &clientServerErr) && clientServerErr.Code == http.StatusUnauthorized {
|
||||
s.logger.DebugContext(ctx, errMsg)
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
||||
return
|
||||
}
|
||||
@@ -260,49 +259,34 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Determine what error to return to the users.
|
||||
if err != nil {
|
||||
var tbErr util.ToolboxError
|
||||
errStr := err.Error()
|
||||
var statusCode int
|
||||
|
||||
if errors.As(err, &tbErr) {
|
||||
switch tbErr.Category() {
|
||||
case util.CategoryAgent:
|
||||
// Agent Errors -> 200 OK
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("Tool invocation agent error: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusOK))
|
||||
return
|
||||
// Upstream API auth error propagation
|
||||
switch {
|
||||
case strings.Contains(errStr, "Error 401"):
|
||||
statusCode = http.StatusUnauthorized
|
||||
case strings.Contains(errStr, "Error 403"):
|
||||
statusCode = http.StatusForbidden
|
||||
}
|
||||
|
||||
case util.CategoryServer:
|
||||
// Server Errors -> Check the specific code inside
|
||||
var clientServerErr *util.ClientServerError
|
||||
statusCode := http.StatusInternalServerError // Default to 500
|
||||
|
||||
if errors.As(err, &clientServerErr) {
|
||||
if clientServerErr.Code != 0 {
|
||||
statusCode = clientServerErr.Code
|
||||
}
|
||||
}
|
||||
|
||||
// Process auth error
|
||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||
if clientAuth {
|
||||
// Token error, pass through 401/403
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("Client credentials lack authorization: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||
return
|
||||
}
|
||||
// ADC/Config error, return 500
|
||||
statusCode = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation server error: %v", err))
|
||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||
if clientAuth {
|
||||
// Propagate the original 401/403 error.
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Unknown error -> 500
|
||||
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation unknown error: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||
// ADC lacking permission or credentials configuration error.
|
||||
internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
|
||||
s.logger.ErrorContext(ctx, internalErr.Error())
|
||||
_ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
|
||||
return
|
||||
}
|
||||
err = fmt.Errorf("error while invoking tool: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
resMarshal, err := json.Marshal(res)
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
@@ -41,140 +40,6 @@ var (
|
||||
_ prompts.Prompt = MockPrompt{}
|
||||
)
|
||||
|
||||
// MockTool is used to mock tools in tests
|
||||
type MockTool struct {
|
||||
Name string
|
||||
Description string
|
||||
Params []parameters.Parameter
|
||||
manifest tools.Manifest
|
||||
unauthorized bool
|
||||
requiresClientAuthrorization bool
|
||||
}
|
||||
|
||||
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) {
|
||||
mock := []any{t.Name}
|
||||
return mock, nil
|
||||
}
|
||||
|
||||
func (t MockTool) ToConfig() tools.ToolConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
// claims is a map of user info decoded from an auth token
|
||||
func (t MockTool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.Params, data, claimsMap)
|
||||
}
|
||||
|
||||
func (t MockTool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
return parameters.EmbedParams(ctx, t.Params, paramValues, embeddingModelsMap, nil)
|
||||
}
|
||||
|
||||
func (t MockTool) Manifest() tools.Manifest {
|
||||
pMs := make([]parameters.ParameterManifest, 0, len(t.Params))
|
||||
for _, p := range t.Params {
|
||||
pMs = append(pMs, p.Manifest())
|
||||
}
|
||||
return tools.Manifest{Description: t.Description, Parameters: pMs}
|
||||
}
|
||||
|
||||
func (t MockTool) Authorized(verifiedAuthServices []string) bool {
|
||||
// defaulted to true
|
||||
return !t.unauthorized
|
||||
}
|
||||
|
||||
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
|
||||
// defaulted to false
|
||||
return t.requiresClientAuthrorization, nil
|
||||
}
|
||||
|
||||
func (t MockTool) GetParameters() parameters.Parameters {
|
||||
return t.Params
|
||||
}
|
||||
|
||||
func (t MockTool) McpManifest() tools.McpManifest {
|
||||
properties := make(map[string]parameters.ParameterMcpManifest)
|
||||
required := make([]string, 0)
|
||||
authParams := make(map[string][]string)
|
||||
|
||||
for _, p := range t.Params {
|
||||
name := p.GetName()
|
||||
paramManifest, authParamList := p.McpManifest()
|
||||
properties[name] = paramManifest
|
||||
required = append(required, name)
|
||||
|
||||
if len(authParamList) > 0 {
|
||||
authParams[name] = authParamList
|
||||
}
|
||||
}
|
||||
|
||||
toolsSchema := parameters.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: properties,
|
||||
Required: required,
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
InputSchema: toolsSchema,
|
||||
}
|
||||
|
||||
if len(authParams) > 0 {
|
||||
mcpManifest.Metadata = map[string]any{
|
||||
"toolbox/authParams": authParams,
|
||||
}
|
||||
}
|
||||
|
||||
return mcpManifest
|
||||
}
|
||||
|
||||
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
// MockPrompt is used to mock prompts in tests
|
||||
type MockPrompt struct {
|
||||
Name string
|
||||
Description string
|
||||
Args prompts.Arguments
|
||||
}
|
||||
|
||||
func (p MockPrompt) SubstituteParams(vals parameters.ParamValues) (any, error) {
|
||||
return []prompts.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("substituted %s", p.Name),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p MockPrompt) ParseArgs(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
var params parameters.Parameters
|
||||
for _, arg := range p.Args {
|
||||
params = append(params, arg.Parameter)
|
||||
}
|
||||
return parameters.ParseParams(params, data, claimsMap)
|
||||
}
|
||||
|
||||
func (p MockPrompt) Manifest() prompts.Manifest {
|
||||
var argManifests []parameters.ParameterManifest
|
||||
for _, arg := range p.Args {
|
||||
argManifests = append(argManifests, arg.Manifest())
|
||||
}
|
||||
return prompts.Manifest{
|
||||
Description: p.Description,
|
||||
Arguments: argManifests,
|
||||
}
|
||||
}
|
||||
|
||||
func (p MockPrompt) McpManifest() prompts.McpManifest {
|
||||
return prompts.GetMcpManifest(p.Name, p.Description, p.Args)
|
||||
}
|
||||
|
||||
func (p MockPrompt) ToConfig() prompts.PromptConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
var tool1 = MockTool{
|
||||
Name: "no_params",
|
||||
Params: []parameters.Parameter{},
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -443,17 +444,15 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
code := rpcResponse.Error.Code
|
||||
switch code {
|
||||
case jsonrpc.INTERNAL_ERROR:
|
||||
// Map Internal RPC Error (-32603) to HTTP 500
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
case jsonrpc.INVALID_REQUEST:
|
||||
var clientServerErr *util.ClientServerError
|
||||
if errors.As(err, &clientServerErr) {
|
||||
switch clientServerErr.Code {
|
||||
case http.StatusUnauthorized:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
case http.StatusForbidden:
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else if strings.Contains(errStr, "Error 401") {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else if strings.Contains(errStr, "Error 403") {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
@@ -123,11 +124,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.NewClientServerError(
|
||||
"missing access token in the 'Authorization' header",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,11 +172,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// Check if any of the specified auth services is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||
if !isAuthorized {
|
||||
err = util.NewClientServerError(
|
||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, "tool invocation authorized")
|
||||
@@ -201,44 +194,30 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
var tbErr util.ToolboxError
|
||||
|
||||
if errors.As(err, &tbErr) {
|
||||
switch tbErr.Category() {
|
||||
case util.CategoryAgent:
|
||||
// MCP - Tool execution error
|
||||
// Return SUCCESS but with IsError: true
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
|
||||
case util.CategoryServer:
|
||||
// MCP Spec - Protocol error
|
||||
// Return JSON-RPC ERROR
|
||||
var clientServerErr *util.ClientServerError
|
||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||
|
||||
if errors.As(err, &clientServerErr) {
|
||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||
if clientAuth {
|
||||
rpcCode = jsonrpc.INVALID_REQUEST
|
||||
} else {
|
||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||
errStr := err.Error()
|
||||
// Missing authService tokens.
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
} else {
|
||||
// Unknown error -> 500
|
||||
// Auth error with ADC should raise internal 500 error
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
@@ -123,11 +124,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.NewClientServerError(
|
||||
"missing access token in the 'Authorization' header",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,11 +172,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// Check if any of the specified auth services is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||
if !isAuthorized {
|
||||
err = util.NewClientServerError(
|
||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, "tool invocation authorized")
|
||||
@@ -201,45 +194,31 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
var tbErr util.ToolboxError
|
||||
|
||||
if errors.As(err, &tbErr) {
|
||||
switch tbErr.Category() {
|
||||
case util.CategoryAgent:
|
||||
// MCP - Tool execution error
|
||||
// Return SUCCESS but with IsError: true
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
|
||||
case util.CategoryServer:
|
||||
// MCP Spec - Protocol error
|
||||
// Return JSON-RPC ERROR
|
||||
var clientServerErr *util.ClientServerError
|
||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||
|
||||
if errors.As(err, &clientServerErr) {
|
||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||
if clientAuth {
|
||||
rpcCode = jsonrpc.INVALID_REQUEST
|
||||
} else {
|
||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||
errStr := err.Error()
|
||||
// Missing authService tokens.
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
} else {
|
||||
// Unknown error -> 500
|
||||
// Auth error with ADC should raise internal 500 error
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
|
||||
sliceRes, ok := results.([]any)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
@@ -116,12 +117,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
errMsg := "missing access token in the 'Authorization' header"
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
||||
errMsg,
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,11 +165,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// Check if any of the specified auth services is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||
if !isAuthorized {
|
||||
err = util.NewClientServerError(
|
||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, "tool invocation authorized")
|
||||
@@ -195,44 +187,29 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
var tbErr util.ToolboxError
|
||||
|
||||
if errors.As(err, &tbErr) {
|
||||
switch tbErr.Category() {
|
||||
case util.CategoryAgent:
|
||||
// MCP - Tool execution error
|
||||
// Return SUCCESS but with IsError: true
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
|
||||
case util.CategoryServer:
|
||||
// MCP Spec - Protocol error
|
||||
// Return JSON-RPC ERROR
|
||||
var clientServerErr *util.ClientServerError
|
||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||
|
||||
if errors.As(err, &clientServerErr) {
|
||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||
if clientAuth {
|
||||
rpcCode = jsonrpc.INVALID_REQUEST
|
||||
} else {
|
||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||
errStr := err.Error()
|
||||
// Missing authService tokens.
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
} else {
|
||||
// Unknown error -> 500
|
||||
// Auth error with ADC should raise internal 500 error
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
@@ -116,11 +117,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.NewClientServerError(
|
||||
"missing access token in the 'Authorization' header",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,11 +165,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// Check if any of the specified auth services is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||
if !isAuthorized {
|
||||
err = util.NewClientServerError(
|
||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, "tool invocation authorized")
|
||||
@@ -194,44 +187,29 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
var tbErr util.ToolboxError
|
||||
|
||||
if errors.As(err, &tbErr) {
|
||||
switch tbErr.Category() {
|
||||
case util.CategoryAgent:
|
||||
// MCP - Tool execution error
|
||||
// Return SUCCESS but with IsError: true
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
|
||||
case util.CategoryServer:
|
||||
// MCP Spec - Protocol error
|
||||
// Return JSON-RPC ERROR
|
||||
var clientServerErr *util.ClientServerError
|
||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||
|
||||
if errors.As(err, &clientServerErr) {
|
||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||
if clientAuth {
|
||||
rpcCode = jsonrpc.INVALID_REQUEST
|
||||
} else {
|
||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||
errStr := err.Error()
|
||||
// Missing authService tokens.
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
} else {
|
||||
// Unknown error -> 500
|
||||
// Auth error with ADC should raise internal 500 error
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
|
||||
159
internal/server/mocks.go
Normal file
159
internal/server/mocks.go
Normal file
@@ -0,0 +1,159 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
// MockTool is used to mock tools in tests
|
||||
type MockTool struct {
|
||||
Name string
|
||||
Description string
|
||||
Params []parameters.Parameter
|
||||
manifest tools.Manifest
|
||||
unauthorized bool
|
||||
requiresClientAuthrorization bool
|
||||
}
|
||||
|
||||
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) {
|
||||
mock := []any{t.Name}
|
||||
return mock, nil
|
||||
}
|
||||
|
||||
func (t MockTool) ToConfig() tools.ToolConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
// claims is a map of user info decoded from an auth token
|
||||
func (t MockTool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.Params, data, claimsMap)
|
||||
}
|
||||
|
||||
func (t MockTool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
return parameters.EmbedParams(ctx, t.Params, paramValues, embeddingModelsMap, nil)
|
||||
}
|
||||
|
||||
func (t MockTool) Manifest() tools.Manifest {
|
||||
pMs := make([]parameters.ParameterManifest, 0, len(t.Params))
|
||||
for _, p := range t.Params {
|
||||
pMs = append(pMs, p.Manifest())
|
||||
}
|
||||
return tools.Manifest{Description: t.Description, Parameters: pMs}
|
||||
}
|
||||
|
||||
func (t MockTool) Authorized(verifiedAuthServices []string) bool {
|
||||
// defaulted to true
|
||||
return !t.unauthorized
|
||||
}
|
||||
|
||||
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
|
||||
// defaulted to false
|
||||
return t.requiresClientAuthrorization, nil
|
||||
}
|
||||
|
||||
func (t MockTool) GetParameters() parameters.Parameters {
|
||||
return t.Params
|
||||
}
|
||||
|
||||
func (t MockTool) McpManifest() tools.McpManifest {
|
||||
properties := make(map[string]parameters.ParameterMcpManifest)
|
||||
required := make([]string, 0)
|
||||
authParams := make(map[string][]string)
|
||||
|
||||
for _, p := range t.Params {
|
||||
name := p.GetName()
|
||||
paramManifest, authParamList := p.McpManifest()
|
||||
properties[name] = paramManifest
|
||||
required = append(required, name)
|
||||
|
||||
if len(authParamList) > 0 {
|
||||
authParams[name] = authParamList
|
||||
}
|
||||
}
|
||||
|
||||
toolsSchema := parameters.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: properties,
|
||||
Required: required,
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
InputSchema: toolsSchema,
|
||||
}
|
||||
|
||||
if len(authParams) > 0 {
|
||||
mcpManifest.Metadata = map[string]any{
|
||||
"toolbox/authParams": authParams,
|
||||
}
|
||||
}
|
||||
|
||||
return mcpManifest
|
||||
}
|
||||
|
||||
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
|
||||
// MockPrompt is used to mock prompts in tests
|
||||
type MockPrompt struct {
|
||||
Name string
|
||||
Description string
|
||||
Args prompts.Arguments
|
||||
}
|
||||
|
||||
func (p MockPrompt) SubstituteParams(vals parameters.ParamValues) (any, error) {
|
||||
return []prompts.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("substituted %s", p.Name),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p MockPrompt) ParseArgs(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
var params parameters.Parameters
|
||||
for _, arg := range p.Args {
|
||||
params = append(params, arg.Parameter)
|
||||
}
|
||||
return parameters.ParseParams(params, data, claimsMap)
|
||||
}
|
||||
|
||||
func (p MockPrompt) Manifest() prompts.Manifest {
|
||||
var argManifests []parameters.ParameterManifest
|
||||
for _, arg := range p.Args {
|
||||
argManifests = append(argManifests, arg.Manifest())
|
||||
}
|
||||
return prompts.Manifest{
|
||||
Description: p.Description,
|
||||
Arguments: argManifests,
|
||||
}
|
||||
}
|
||||
|
||||
func (p MockPrompt) McpManifest() prompts.McpManifest {
|
||||
return prompts.GetMcpManifest(p.Name, p.Description, p.Args)
|
||||
}
|
||||
|
||||
func (p MockPrompt) ToConfig() prompts.PromptConfig {
|
||||
return nil
|
||||
}
|
||||
@@ -184,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if source.UseClientAuthorization() {
|
||||
// Use client-side access token
|
||||
if accessToken == "" {
|
||||
return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil)
|
||||
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
|
||||
}
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
|
||||
@@ -17,7 +17,6 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
@@ -81,7 +80,7 @@ type AccessToken string
|
||||
func (token AccessToken) ParseBearerToken() (string, error) {
|
||||
headerParts := strings.Split(string(token), " ")
|
||||
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
|
||||
return "", util.NewClientServerError("authorization header must be in the format 'Bearer <token>'", http.StatusUnauthorized, nil)
|
||||
return "", fmt.Errorf("authorization header must be in the format 'Bearer <token>': %w", util.ErrUnauthorized)
|
||||
}
|
||||
return headerParts[1], nil
|
||||
}
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
// Copyright 2026 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package util
|
||||
|
||||
import "fmt"
|
||||
|
||||
type ErrorCategory string
|
||||
|
||||
const (
|
||||
CategoryAgent ErrorCategory = "AGENT_ERROR"
|
||||
CategoryServer ErrorCategory = "SERVER_ERROR"
|
||||
)
|
||||
|
||||
// ToolboxError is the interface all custom errors must satisfy
|
||||
type ToolboxError interface {
|
||||
error
|
||||
Category() ErrorCategory
|
||||
}
|
||||
|
||||
// Agent Errors return 200 to the sender
|
||||
type AgentError struct {
|
||||
Msg string
|
||||
Cause error
|
||||
}
|
||||
|
||||
func (e *AgentError) Error() string { return e.Msg }
|
||||
|
||||
func (e *AgentError) Category() ErrorCategory { return CategoryAgent }
|
||||
|
||||
func (e *AgentError) Unwrap() error { return e.Cause }
|
||||
|
||||
func NewAgentError(msg string, cause error) *AgentError {
|
||||
return &AgentError{Msg: msg, Cause: cause}
|
||||
}
|
||||
|
||||
// ClientServerError returns 4XX/5XX error code
|
||||
type ClientServerError struct {
|
||||
Msg string
|
||||
Code int
|
||||
Cause error
|
||||
}
|
||||
|
||||
func (e *ClientServerError) Error() string { return fmt.Sprintf("%s: %v", e.Msg, e.Cause) }
|
||||
|
||||
func (e *ClientServerError) Category() ErrorCategory { return CategoryServer }
|
||||
|
||||
func (e *ClientServerError) Unwrap() error { return e.Cause }
|
||||
|
||||
func NewClientServerError(msg string, code int, cause error) *ClientServerError {
|
||||
return &ClientServerError{Msg: msg, Code: code, Cause: cause}
|
||||
}
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"slices"
|
||||
@@ -119,7 +118,7 @@ func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[st
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
return nil, util.NewClientServerError("missing or invalid authentication header", http.StatusUnauthorized, nil)
|
||||
return nil, fmt.Errorf("missing or invalid authentication header: %w", util.ErrUnauthorized)
|
||||
}
|
||||
|
||||
// CheckParamRequired checks if a parameter is required based on the required and default field.
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -187,3 +188,5 @@ func InstrumentationFromContext(ctx context.Context) (*telemetry.Instrumentation
|
||||
}
|
||||
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
||||
}
|
||||
|
||||
var ErrUnauthorized = errors.New("unauthorized")
|
||||
|
||||
Reference in New Issue
Block a user