mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-04 12:15:09 -05:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d903aafd31 | ||
|
|
3f1908a822 | ||
|
|
eef7a94977 |
@@ -35,6 +35,7 @@ import (
|
|||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||||
"github.com/googleapis/genai-toolbox/internal/cli/invoke"
|
"github.com/googleapis/genai-toolbox/internal/cli/invoke"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/cli/skills"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/log"
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||||
@@ -401,6 +402,8 @@ func NewCommand(opts ...Option) *Command {
|
|||||||
|
|
||||||
// Register subcommands for tool invocation
|
// Register subcommands for tool invocation
|
||||||
baseCmd.AddCommand(invoke.NewCommand(cmd))
|
baseCmd.AddCommand(invoke.NewCommand(cmd))
|
||||||
|
// Register subcommands for skill generation
|
||||||
|
baseCmd.AddCommand(skills.NewCommand(cmd))
|
||||||
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|||||||
179
cmd/skill_generate_test.go
Normal file
179
cmd/skill_generate_test.go
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateSkill(t *testing.T) {
|
||||||
|
// Create a temporary directory for tests
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
outputDir := filepath.Join(tmpDir, "skills")
|
||||||
|
|
||||||
|
// Create a tools.yaml file with a sqlite tool
|
||||||
|
toolsFileContent := `
|
||||||
|
sources:
|
||||||
|
my-sqlite:
|
||||||
|
kind: sqlite
|
||||||
|
database: test.db
|
||||||
|
tools:
|
||||||
|
hello-sqlite:
|
||||||
|
kind: sqlite-sql
|
||||||
|
source: my-sqlite
|
||||||
|
description: "hello tool"
|
||||||
|
statement: "SELECT 'hello' as greeting"
|
||||||
|
`
|
||||||
|
|
||||||
|
toolsFilePath := filepath.Join(tmpDir, "tools.yaml")
|
||||||
|
if err := os.WriteFile(toolsFilePath, []byte(toolsFileContent), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write tools file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []string{
|
||||||
|
"skills-generate",
|
||||||
|
"--tools-file", toolsFilePath,
|
||||||
|
"--output-dir", outputDir,
|
||||||
|
"--name", "hello-sqlite",
|
||||||
|
"--description", "hello tool",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, got, err := invokeCommandWithContext(ctx, args)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("command failed: %v\nOutput: %s", err, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify generated directory structure
|
||||||
|
skillPath := filepath.Join(outputDir, "hello-sqlite")
|
||||||
|
if _, err := os.Stat(skillPath); os.IsNotExist(err) {
|
||||||
|
t.Fatalf("skill directory not created: %s", skillPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check SKILL.md
|
||||||
|
skillMarkdown := filepath.Join(skillPath, "SKILL.md")
|
||||||
|
content, err := os.ReadFile(skillMarkdown)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read SKILL.md: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedFrontmatter := `---
|
||||||
|
name: hello-sqlite
|
||||||
|
description: hello tool
|
||||||
|
---`
|
||||||
|
if !strings.HasPrefix(string(content), expectedFrontmatter) {
|
||||||
|
t.Errorf("SKILL.md does not have expected frontmatter format.\nExpected prefix:\n%s\nGot:\n%s", expectedFrontmatter, string(content))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(string(content), "## Usage") {
|
||||||
|
t.Errorf("SKILL.md does not contain '## Usage' section")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(string(content), "## Scripts") {
|
||||||
|
t.Errorf("SKILL.md does not contain '## Scripts' section")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(string(content), "### hello-sqlite") {
|
||||||
|
t.Errorf("SKILL.md does not contain '### hello-sqlite' tool header")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check script file
|
||||||
|
scriptFilename := "hello-sqlite.js"
|
||||||
|
scriptPath := filepath.Join(skillPath, "scripts", scriptFilename)
|
||||||
|
if _, err := os.Stat(scriptPath); os.IsNotExist(err) {
|
||||||
|
t.Fatalf("script file not created: %s", scriptPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
scriptContent, err := os.ReadFile(scriptPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read script file: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(scriptContent), "hello-sqlite") {
|
||||||
|
t.Errorf("script file does not contain expected tool name")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check assets
|
||||||
|
assetPath := filepath.Join(skillPath, "assets", "hello-sqlite.yaml")
|
||||||
|
if _, err := os.Stat(assetPath); os.IsNotExist(err) {
|
||||||
|
t.Fatalf("asset file not created: %s", assetPath)
|
||||||
|
}
|
||||||
|
assetContent, err := os.ReadFile(assetPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read asset file: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(assetContent), "hello-sqlite") {
|
||||||
|
t.Errorf("asset file does not contain expected tool name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSkill_NoConfig(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
outputDir := filepath.Join(tmpDir, "skills")
|
||||||
|
|
||||||
|
args := []string{
|
||||||
|
"skills-generate",
|
||||||
|
"--output-dir", outputDir,
|
||||||
|
"--name", "test",
|
||||||
|
"--description", "test",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err := invokeCommandWithContext(context.Background(), args)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected command to fail when no configuration is provided and tools.yaml is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not have created the directory if no config was processed
|
||||||
|
if _, err := os.Stat(outputDir); !os.IsNotExist(err) {
|
||||||
|
t.Errorf("output directory should not have been created")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSkill_MissingArguments(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
toolsFilePath := filepath.Join(tmpDir, "tools.yaml")
|
||||||
|
if err := os.WriteFile(toolsFilePath, []byte("tools: {}"), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write tools file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing name",
|
||||||
|
args: []string{"skills-generate", "--tools-file", toolsFilePath, "--description", "test"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing description",
|
||||||
|
args: []string{"skills-generate", "--tools-file", toolsFilePath, "--name", "test"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, got, err := invokeCommandWithContext(context.Background(), tt.args)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected command to fail due to missing arguments, but it succeeded\nOutput: %s", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -53,7 +53,7 @@ export async function main() {
|
|||||||
|
|
||||||
for (const query of queries) {
|
for (const query of queries) {
|
||||||
conversationHistory.push({ role: "user", content: [{ text: query }] });
|
conversationHistory.push({ role: "user", content: [{ text: query }] });
|
||||||
const response = await ai.generate({
|
let response = await ai.generate({
|
||||||
messages: conversationHistory,
|
messages: conversationHistory,
|
||||||
tools: tools,
|
tools: tools,
|
||||||
});
|
});
|
||||||
|
|||||||
112
docs/en/how-to/generate_skill.md
Normal file
112
docs/en/how-to/generate_skill.md
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
---
|
||||||
|
title: "Generate Agent Skills"
|
||||||
|
type: docs
|
||||||
|
weight: 10
|
||||||
|
description: >
|
||||||
|
How to generate agent skills from a toolset.
|
||||||
|
---
|
||||||
|
|
||||||
|
The `skills-generate` command allows you to convert a **toolset** into an **Agent Skill**. A toolset is a collection of tools, and the generated skill will contain metadata and execution scripts for all tools within that toolset, complying with the [Agent Skill specification](https://agentskills.io/specification).
|
||||||
|
|
||||||
|
## Before you begin
|
||||||
|
|
||||||
|
1. Make sure you have the `toolbox` executable in your PATH.
|
||||||
|
2. Make sure you have [Node.js](https://nodejs.org/) installed on your system.
|
||||||
|
|
||||||
|
## Generating a Skill from a Toolset
|
||||||
|
|
||||||
|
A skill package consists of a `SKILL.md` file (with required YAML frontmatter) and a set of Node.js scripts. Each tool defined in your toolset maps to a corresponding script in the generated Node.js scripts (`.js`) that work across different platforms (Linux, macOS, Windows).
|
||||||
|
|
||||||
|
|
||||||
|
### Command Usage
|
||||||
|
|
||||||
|
The basic syntax for the command is:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
toolbox <tool-source> skills-generate \
|
||||||
|
--name <skill-name> \
|
||||||
|
--toolset <toolset-name> \
|
||||||
|
--description <description> \
|
||||||
|
--output-dir <output-directory>
|
||||||
|
```
|
||||||
|
|
||||||
|
- `<tool-source>`: Can be `--tools-file`, `--tools-files`, `--tools-folder`, and `--prebuilt`. See the [CLI Reference](../reference/cli.md) for details.
|
||||||
|
- `--name`: Name of the generated skill.
|
||||||
|
- `--description`: Description of the generated skill.
|
||||||
|
- `--toolset`: (Optional) Name of the toolset to convert into a skill. If not provided, all tools will be included.
|
||||||
|
- `--output-dir`: (Optional) Directory to output generated skills (default: "skills").
|
||||||
|
|
||||||
|
{{< notice note >}}
|
||||||
|
**Note:** The `<skill-name>` must follow the Agent Skill [naming convention](https://agentskills.io/specification): it must contain only lowercase alphanumeric characters and hyphens, cannot start or end with a hyphen, and cannot contain consecutive hyphens (e.g., `my-skill`, `data-processing`).
|
||||||
|
{{< /notice >}}
|
||||||
|
|
||||||
|
### Example: Custom Tools File
|
||||||
|
|
||||||
|
1. Create a `tools.yaml` file with a toolset and some tools:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
tool_a:
|
||||||
|
description: "First tool"
|
||||||
|
run:
|
||||||
|
command: "echo 'Tool A'"
|
||||||
|
tool_b:
|
||||||
|
description: "Second tool"
|
||||||
|
run:
|
||||||
|
command: "echo 'Tool B'"
|
||||||
|
toolsets:
|
||||||
|
my_toolset:
|
||||||
|
tools:
|
||||||
|
- tool_a
|
||||||
|
- tool_b
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Generate the skill:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
toolbox --tools-file tools.yaml skills-generate \
|
||||||
|
--name "my-skill" \
|
||||||
|
--toolset "my_toolset" \
|
||||||
|
--description "A skill containing multiple tools" \
|
||||||
|
--output-dir "generated-skills"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. The generated skill directory structure:
|
||||||
|
|
||||||
|
```text
|
||||||
|
generated-skills/
|
||||||
|
└── my-skill/
|
||||||
|
├── SKILL.md
|
||||||
|
├── assets/
|
||||||
|
│ ├── tool_a.yaml
|
||||||
|
│ └── tool_b.yaml
|
||||||
|
└── scripts/
|
||||||
|
├── tool_a.js
|
||||||
|
└── tool_b.js
|
||||||
|
```
|
||||||
|
|
||||||
|
In this example, the skill contains two Node.js scripts (`tool_a.js` and `tool_b.js`), each mapping to a tool in the original toolset.
|
||||||
|
|
||||||
|
### Example: Prebuilt Configuration
|
||||||
|
|
||||||
|
You can also generate skills from prebuilt toolsets:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
toolbox --prebuilt alloydb-postgres-admin skills-generate \
|
||||||
|
--name "alloydb-postgres-admin" \
|
||||||
|
--description "skill for performing administrative operations on alloydb"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Installing the Generated Skill in Gemini CLI
|
||||||
|
|
||||||
|
Once you have generated a skill, you can install it into the Gemini CLI using the `gemini skills install` command.
|
||||||
|
|
||||||
|
### Installation Command
|
||||||
|
|
||||||
|
Provide the path to the directory containing the generated skill:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gemini skills install /path/to/generated-skills/my-skill
|
||||||
|
```
|
||||||
|
|
||||||
|
Alternatively, use ~/.gemini/skills as the `--output-dir` to generate the skill straight to the Gemini CLI.
|
||||||
@@ -13,21 +13,22 @@ The `invoke` command allows you to invoke tools defined in your configuration di
|
|||||||
|
|
||||||
{{< notice tip >}}
|
{{< notice tip >}}
|
||||||
**Keep configurations minimal:** The `invoke` command initializes *all* resources (sources, tools, etc.) defined in your configuration files during execution. To ensure fast response times, consider using a minimal configuration file containing only the tools you need for the specific invocation.
|
**Keep configurations minimal:** The `invoke` command initializes *all* resources (sources, tools, etc.) defined in your configuration files during execution. To ensure fast response times, consider using a minimal configuration file containing only the tools you need for the specific invocation.
|
||||||
{{< notice tip >}}
|
{{< /notice >}}
|
||||||
|
|
||||||
## Prerequisites
|
## Before you begin
|
||||||
|
|
||||||
- You have the `toolbox` binary installed or built.
|
1. Make sure you have the `toolbox` binary installed or built.
|
||||||
- You have a valid tool configuration file (e.g., `tools.yaml`).
|
2. Make sure you have a valid tool configuration file (e.g., `tools.yaml`).
|
||||||
|
|
||||||
## Basic Usage
|
### Command Usage
|
||||||
|
|
||||||
The basic syntax for the command is:
|
The basic syntax for the command is:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
toolbox [--tools-file <path> | --prebuilt <name>] invoke <tool-name> [params]
|
toolbox <tool-source> invoke <tool-name> [params]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- `<tool-source>`: Can be `--tools-file`, `--tools-files`, `--tools-folder`, and `--prebuilt`. See the [CLI Reference](../reference/cli.md) for details.
|
||||||
- `<tool-name>`: The name of the tool you want to call. This must match the name defined in your `tools.yaml`.
|
- `<tool-name>`: The name of the tool you want to call. This must match the name defined in your `tools.yaml`.
|
||||||
- `[params]`: (Optional) A JSON string representing the arguments for the tool.
|
- `[params]`: (Optional) A JSON string representing the arguments for the tool.
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ description: >
|
|||||||
|
|
||||||
## Sub Commands
|
## Sub Commands
|
||||||
|
|
||||||
### `invoke`
|
<details>
|
||||||
|
<summary><code>invoke</code></summary>
|
||||||
|
|
||||||
Executes a tool directly with the provided parameters. This is useful for testing tool configurations and parameters without needing a full client setup.
|
Executes a tool directly with the provided parameters. This is useful for testing tool configurations and parameters without needing a full client setup.
|
||||||
|
|
||||||
@@ -42,8 +43,36 @@ Executes a tool directly with the provided parameters. This is useful for testin
|
|||||||
toolbox invoke <tool-name> [params]
|
toolbox invoke <tool-name> [params]
|
||||||
```
|
```
|
||||||
|
|
||||||
- `<tool-name>`: The name of the tool to execute (as defined in your configuration).
|
**Arguments:**
|
||||||
- `[params]`: (Optional) A JSON string containing the parameters for the tool.
|
|
||||||
|
- `tool-name`: The name of the tool to execute (as defined in your configuration).
|
||||||
|
- `params`: (Optional) A JSON string containing the parameters for the tool.
|
||||||
|
|
||||||
|
For more detailed instructions, see [Invoke Tools via CLI](../how-to/invoke_tool.md).
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><code>skills-generate</code></summary>
|
||||||
|
|
||||||
|
Generates a skill package from a specified toolset. Each tool in the toolset will have a corresponding Node.js execution script in the generated skill.
|
||||||
|
|
||||||
|
**Syntax:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
toolbox skills-generate --name <name> --description <description> --toolset <toolset> --output-dir <output>
|
||||||
|
```
|
||||||
|
|
||||||
|
**Flags:**
|
||||||
|
|
||||||
|
- `--name`: Name of the generated skill.
|
||||||
|
- `--description`: Description of the generated skill.
|
||||||
|
- `--toolset`: (Optional) Name of the toolset to convert into a skill. If not provided, all tools will be included.
|
||||||
|
- `--output-dir`: (Optional) Directory to output generated skills (default: "skills").
|
||||||
|
|
||||||
|
For more detailed instructions, see [Generate Agent Skills](../how-to/generate_skill.md).
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
|||||||
237
internal/cli/skills/command.go
Normal file
237
internal/cli/skills/command.go
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package skills
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
_ "embed"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/server"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RootCommand defines the interface for required by skills-generate subcommand.
|
||||||
|
// This allows subcommands to access shared resources and functionality without
|
||||||
|
// direct coupling to the root command's implementation.
|
||||||
|
type RootCommand interface {
|
||||||
|
// Config returns a copy of the current server configuration.
|
||||||
|
Config() server.ServerConfig
|
||||||
|
|
||||||
|
// LoadConfig loads and merges the configuration from files, folders, and prebuilts.
|
||||||
|
LoadConfig(ctx context.Context) error
|
||||||
|
|
||||||
|
// Setup initializes the runtime environment, including logging and telemetry.
|
||||||
|
// It returns the updated context and a shutdown function to be called when finished.
|
||||||
|
Setup(ctx context.Context) (context.Context, func(context.Context) error, error)
|
||||||
|
|
||||||
|
// Logger returns the logger instance.
|
||||||
|
Logger() log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Command is the command for generating skills.
|
||||||
|
type Command struct {
|
||||||
|
*cobra.Command
|
||||||
|
rootCmd RootCommand
|
||||||
|
name string
|
||||||
|
description string
|
||||||
|
toolset string
|
||||||
|
outputDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCommand creates a new Command.
|
||||||
|
func NewCommand(rootCmd RootCommand) *cobra.Command {
|
||||||
|
cmd := &Command{
|
||||||
|
rootCmd: rootCmd,
|
||||||
|
}
|
||||||
|
cmd.Command = &cobra.Command{
|
||||||
|
Use: "skills-generate",
|
||||||
|
Short: "Generate skills from tool configurations",
|
||||||
|
RunE: func(c *cobra.Command, args []string) error {
|
||||||
|
return cmd.run(c)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Flags().StringVar(&cmd.name, "name", "", "Name of the generated skill.")
|
||||||
|
cmd.Flags().StringVar(&cmd.description, "description", "", "Description of the generated skill")
|
||||||
|
cmd.Flags().StringVar(&cmd.toolset, "toolset", "", "Name of the toolset to convert into a skill. If not provided, all tools will be included.")
|
||||||
|
cmd.Flags().StringVar(&cmd.outputDir, "output-dir", "skills", "Directory to output generated skills")
|
||||||
|
|
||||||
|
_ = cmd.MarkFlagRequired("name")
|
||||||
|
_ = cmd.MarkFlagRequired("description")
|
||||||
|
return cmd.Command
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Command) run(cmd *cobra.Command) error {
|
||||||
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
ctx, shutdown, err := c.rootCmd.Setup(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = shutdown(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
logger := c.rootCmd.Logger()
|
||||||
|
|
||||||
|
// Load and merge tool configurations
|
||||||
|
if err := c.rootCmd.LoadConfig(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(c.outputDir, 0755); err != nil {
|
||||||
|
errMsg := fmt.Errorf("error creating output directory: %w", err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.InfoContext(ctx, fmt.Sprintf("Generating skill '%s'...", c.name))
|
||||||
|
|
||||||
|
// Initialize toolbox and collect tools
|
||||||
|
allTools, err := c.collectTools(ctx)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("error collecting tools: %w", err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(allTools) == 0 {
|
||||||
|
logger.InfoContext(ctx, "No tools found to generate.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate the combined skill directory
|
||||||
|
skillPath := filepath.Join(c.outputDir, c.name)
|
||||||
|
if err := os.MkdirAll(skillPath, 0755); err != nil {
|
||||||
|
errMsg := fmt.Errorf("error creating skill directory: %w", err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate assets directory
|
||||||
|
assetsPath := filepath.Join(skillPath, "assets")
|
||||||
|
if err := os.MkdirAll(assetsPath, 0755); err != nil {
|
||||||
|
errMsg := fmt.Errorf("error creating assets dir: %w", err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate scripts directory
|
||||||
|
scriptsPath := filepath.Join(skillPath, "scripts")
|
||||||
|
if err := os.MkdirAll(scriptsPath, 0755); err != nil {
|
||||||
|
errMsg := fmt.Errorf("error creating scripts dir: %w", err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate over keys to ensure deterministic order
|
||||||
|
var toolNames []string
|
||||||
|
for name := range allTools {
|
||||||
|
toolNames = append(toolNames, name)
|
||||||
|
}
|
||||||
|
sort.Strings(toolNames)
|
||||||
|
|
||||||
|
for _, toolName := range toolNames {
|
||||||
|
// Generate YAML config in asset directory
|
||||||
|
minimizedContent, err := generateToolConfigYAML(c.rootCmd.Config(), toolName)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("error generating filtered config for %s: %w", toolName, err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
specificToolsFileName := fmt.Sprintf("%s.yaml", toolName)
|
||||||
|
if minimizedContent != nil {
|
||||||
|
destPath := filepath.Join(assetsPath, specificToolsFileName)
|
||||||
|
if err := os.WriteFile(destPath, minimizedContent, 0644); err != nil {
|
||||||
|
errMsg := fmt.Errorf("error writing filtered config for %s: %w", toolName, err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate wrapper script in scripts directory
|
||||||
|
scriptContent, err := generateScriptContent(toolName, specificToolsFileName)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("error generating script content for %s: %w", toolName, err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
scriptFilename := filepath.Join(scriptsPath, fmt.Sprintf("%s.js", toolName))
|
||||||
|
if err := os.WriteFile(scriptFilename, []byte(scriptContent), 0755); err != nil {
|
||||||
|
errMsg := fmt.Errorf("error writing script %s: %w", scriptFilename, err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate SKILL.md
|
||||||
|
skillContent, err := generateSkillMarkdown(c.name, c.description, allTools)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("error generating SKILL.md content: %w", err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
skillMdPath := filepath.Join(skillPath, "SKILL.md")
|
||||||
|
if err := os.WriteFile(skillMdPath, []byte(skillContent), 0644); err != nil {
|
||||||
|
errMsg := fmt.Errorf("error writing SKILL.md: %w", err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.InfoContext(ctx, fmt.Sprintf("Successfully generated skill '%s' with %d tools.", c.name, len(allTools)))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Command) collectTools(ctx context.Context) (map[string]tools.Tool, error) {
|
||||||
|
// Initialize Resources
|
||||||
|
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, c.rootCmd.Config())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize resources: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resourceMgr := resources.NewResourceManager(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
||||||
|
|
||||||
|
result := make(map[string]tools.Tool)
|
||||||
|
|
||||||
|
if c.toolset == "" {
|
||||||
|
return toolsMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ts, ok := resourceMgr.GetToolset(c.toolset)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("toolset %q not found", c.toolset)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range ts.Tools {
|
||||||
|
if t != nil {
|
||||||
|
tool := *t
|
||||||
|
result[tool.McpManifest().Name] = tool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
297
internal/cli/skills/generator.go
Normal file
297
internal/cli/skills/generator.go
Normal file
@@ -0,0 +1,297 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package skills
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/goccy/go-yaml"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/server"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
const skillTemplate = `---
|
||||||
|
name: {{.SkillName}}
|
||||||
|
description: {{.SkillDescription}}
|
||||||
|
---
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
All scripts can be executed using Node.js. Replace ` + "`" + `<param_name>` + "`" + ` and ` + "`" + `<param_value>` + "`" + ` with actual values.
|
||||||
|
|
||||||
|
**Bash:**
|
||||||
|
` + "`" + `node scripts/<script_name>.js '{"<param_name>": "<param_value>"}'` + "`" + `
|
||||||
|
|
||||||
|
**PowerShell:**
|
||||||
|
` + "`" + `node scripts/<script_name>.js '{\"<param_name>\": \"<param_value>\"}'` + "`" + `
|
||||||
|
|
||||||
|
## Scripts
|
||||||
|
|
||||||
|
{{range .Tools}}
|
||||||
|
### {{.Name}}
|
||||||
|
|
||||||
|
{{.Description}}
|
||||||
|
|
||||||
|
{{.ParametersSchema}}
|
||||||
|
|
||||||
|
---
|
||||||
|
{{end}}
|
||||||
|
`
|
||||||
|
|
||||||
|
type toolTemplateData struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
ParametersSchema string
|
||||||
|
}
|
||||||
|
|
||||||
|
type skillTemplateData struct {
|
||||||
|
SkillName string
|
||||||
|
SkillDescription string
|
||||||
|
Tools []toolTemplateData
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateSkillMarkdown generates the content of the SKILL.md file.
|
||||||
|
// It includes usage instructions and a reference section for each tool in the skill,
|
||||||
|
// detailing its description and parameters.
|
||||||
|
func generateSkillMarkdown(skillName, skillDescription string, toolsMap map[string]tools.Tool) (string, error) {
|
||||||
|
var toolsData []toolTemplateData
|
||||||
|
|
||||||
|
// Order tools based on name
|
||||||
|
var toolNames []string
|
||||||
|
for name := range toolsMap {
|
||||||
|
toolNames = append(toolNames, name)
|
||||||
|
}
|
||||||
|
sort.Strings(toolNames)
|
||||||
|
|
||||||
|
for _, name := range toolNames {
|
||||||
|
tool := toolsMap[name]
|
||||||
|
manifest := tool.Manifest()
|
||||||
|
|
||||||
|
parametersSchema, err := formatParameters(manifest.Parameters)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
toolsData = append(toolsData, toolTemplateData{
|
||||||
|
Name: name,
|
||||||
|
Description: manifest.Description,
|
||||||
|
ParametersSchema: parametersSchema,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
data := skillTemplateData{
|
||||||
|
SkillName: skillName,
|
||||||
|
SkillDescription: skillDescription,
|
||||||
|
Tools: toolsData,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl, err := template.New("markdown").Parse(skillTemplate)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error parsing markdown template: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
if err := tmpl.Execute(&buf, data); err != nil {
|
||||||
|
return "", fmt.Errorf("error executing markdown template: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const nodeScriptTemplate = `#!/usr/bin/env node
|
||||||
|
|
||||||
|
const { spawn, execSync } = require('child_process');
|
||||||
|
const path = require('path');
|
||||||
|
const fs = require('fs');
|
||||||
|
|
||||||
|
const toolName = "{{.Name}}";
|
||||||
|
const toolsFileName = "{{.ToolsFileName}}";
|
||||||
|
|
||||||
|
function getToolboxPath() {
|
||||||
|
try {
|
||||||
|
const checkCommand = process.platform === 'win32' ? 'where toolbox' : 'which toolbox';
|
||||||
|
const globalPath = execSync(checkCommand, { stdio: 'pipe', encoding: 'utf-8' }).trim();
|
||||||
|
if (globalPath) {
|
||||||
|
return globalPath.split('\n')[0].trim();
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
// Ignore error;
|
||||||
|
}
|
||||||
|
const localPath = path.resolve(__dirname, '../../../toolbox');
|
||||||
|
if (fs.existsSync(localPath)) {
|
||||||
|
return localPath;
|
||||||
|
}
|
||||||
|
throw new Error("Toolbox binary not found");
|
||||||
|
}
|
||||||
|
|
||||||
|
let toolboxBinary;
|
||||||
|
try {
|
||||||
|
toolboxBinary = getToolboxPath();
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Error:", err.message);
|
||||||
|
process.exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let configArgs = [];
|
||||||
|
if (toolsFileName) {
|
||||||
|
configArgs.push("--tools-file", path.join(__dirname, "..", "assets", toolsFileName));
|
||||||
|
}
|
||||||
|
|
||||||
|
const args = process.argv.slice(2);
|
||||||
|
const toolboxArgs = [...configArgs, "invoke", toolName, ...args];
|
||||||
|
|
||||||
|
const child = spawn(toolboxBinary, toolboxArgs, { stdio: 'inherit' });
|
||||||
|
|
||||||
|
child.on('close', (code) => {
|
||||||
|
process.exit(code);
|
||||||
|
});
|
||||||
|
|
||||||
|
child.on('error', (err) => {
|
||||||
|
console.error("Error executing toolbox:", err);
|
||||||
|
process.exit(1);
|
||||||
|
});
|
||||||
|
`
|
||||||
|
|
||||||
|
type scriptData struct {
|
||||||
|
Name string
|
||||||
|
ToolsFileName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateScriptContent creates the content for a Node.js wrapper script.
|
||||||
|
// This script invokes the toolbox CLI with the appropriate configuration
|
||||||
|
// (using a generated tools file) and arguments to execute the specific tool.
|
||||||
|
func generateScriptContent(name string, toolsFileName string) (string, error) {
|
||||||
|
data := scriptData{
|
||||||
|
Name: name,
|
||||||
|
ToolsFileName: toolsFileName,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl, err := template.New("script").Parse(nodeScriptTemplate)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error parsing script template: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
if err := tmpl.Execute(&buf, data); err != nil {
|
||||||
|
return "", fmt.Errorf("error executing script template: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatParameters converts a list of parameter manifests into a formatted JSON schema string.
|
||||||
|
// This schema is used in the skill documentation to describe the input parameters for a tool.
|
||||||
|
func formatParameters(params []parameters.ParameterManifest) (string, error) {
|
||||||
|
if len(params) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
properties := make(map[string]interface{})
|
||||||
|
var required []string
|
||||||
|
|
||||||
|
for _, p := range params {
|
||||||
|
paramMap := map[string]interface{}{
|
||||||
|
"type": p.Type,
|
||||||
|
"description": p.Description,
|
||||||
|
}
|
||||||
|
if p.Default != nil {
|
||||||
|
paramMap["default"] = p.Default
|
||||||
|
}
|
||||||
|
properties[p.Name] = paramMap
|
||||||
|
if p.Required {
|
||||||
|
required = append(required, p.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
schema := map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": properties,
|
||||||
|
}
|
||||||
|
if len(required) > 0 {
|
||||||
|
schema["required"] = required
|
||||||
|
}
|
||||||
|
|
||||||
|
schemaJSON, err := json.MarshalIndent(schema, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error generating parameters schema: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("#### Parameters\n\n```json\n%s\n```", string(schemaJSON)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateToolConfigYAML generates the YAML configuration for a single tool and its dependency (source).
|
||||||
|
// It extracts the relevant tool and source configurations from the server config and formats them
|
||||||
|
// into a YAML document suitable for inclusion in the skill's assets.
|
||||||
|
func generateToolConfigYAML(cfg server.ServerConfig, toolName string) ([]byte, error) {
|
||||||
|
toolCfg, ok := cfg.ToolConfigs[toolName]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("error finding tool config: %s", toolName)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
encoder := yaml.NewEncoder(&buf)
|
||||||
|
|
||||||
|
// Process Tool Config
|
||||||
|
toolWrapper := struct {
|
||||||
|
Kind string `yaml:"kind"`
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
Config tools.ToolConfig `yaml:",inline"`
|
||||||
|
}{
|
||||||
|
Kind: "tools",
|
||||||
|
Name: toolName,
|
||||||
|
Config: toolCfg,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := encoder.Encode(toolWrapper); err != nil {
|
||||||
|
return nil, fmt.Errorf("error encoding tool config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process Source Config
|
||||||
|
var toolMap map[string]interface{}
|
||||||
|
b, err := yaml.Marshal(toolCfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error marshaling tool config: %w", err)
|
||||||
|
}
|
||||||
|
if err := yaml.Unmarshal(b, &toolMap); err != nil {
|
||||||
|
return nil, fmt.Errorf("error unmarshaling tool config map: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sourceName, ok := toolMap["source"].(string); ok {
|
||||||
|
if sourceCfg, ok := cfg.SourceConfigs[sourceName]; ok {
|
||||||
|
sourceWrapper := struct {
|
||||||
|
Kind string `yaml:"kind"`
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
Config sources.SourceConfig `yaml:",inline"`
|
||||||
|
}{
|
||||||
|
Kind: "sources",
|
||||||
|
Name: sourceName,
|
||||||
|
Config: sourceCfg,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := encoder.Encode(sourceWrapper); err != nil {
|
||||||
|
return nil, fmt.Errorf("error encoding source config: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
338
internal/cli/skills/generator_test.go
Normal file
338
internal/cli/skills/generator_test.go
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package skills
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/server"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockToolConfig struct {
|
||||||
|
Type string `yaml:"type"`
|
||||||
|
Source string `yaml:"source"`
|
||||||
|
Other string `yaml:"other"`
|
||||||
|
Parameters parameters.Parameters `yaml:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockToolConfig) ToolConfigType() string {
|
||||||
|
return m.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockToolConfig) Initialize(map[string]sources.Source) (tools.Tool, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockSourceConfig struct {
|
||||||
|
TypeVal string
|
||||||
|
ConnVal string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockSourceConfig) SourceConfigType() string {
|
||||||
|
return m.TypeVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockSourceConfig) Initialize(context.Context, trace.Tracer) (sources.Source, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockSourceConfig) MarshalYAML() (interface{}, error) {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"connection_string": m.ConnVal,
|
||||||
|
"type": m.TypeVal,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatParameters(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
params []parameters.ParameterManifest
|
||||||
|
wantContains []string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty parameters",
|
||||||
|
params: []parameters.ParameterManifest{},
|
||||||
|
wantContains: []string{""},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single required string parameter",
|
||||||
|
params: []parameters.ParameterManifest{
|
||||||
|
{
|
||||||
|
Name: "param1",
|
||||||
|
Description: "A test parameter",
|
||||||
|
Type: "string",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantContains: []string{
|
||||||
|
"## Parameters",
|
||||||
|
"```json",
|
||||||
|
`"type": "object"`,
|
||||||
|
`"properties": {`,
|
||||||
|
`"param1": {`,
|
||||||
|
`"type": "string"`,
|
||||||
|
`"description": "A test parameter"`,
|
||||||
|
`"required": [`,
|
||||||
|
`"param1"`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed parameters with defaults",
|
||||||
|
params: []parameters.ParameterManifest{
|
||||||
|
{
|
||||||
|
Name: "param1",
|
||||||
|
Description: "Param 1",
|
||||||
|
Type: "string",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "param2",
|
||||||
|
Description: "Param 2",
|
||||||
|
Type: "integer",
|
||||||
|
Default: 42,
|
||||||
|
Required: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantContains: []string{
|
||||||
|
`"param1": {`,
|
||||||
|
`"param2": {`,
|
||||||
|
`"default": 42`,
|
||||||
|
`"required": [`,
|
||||||
|
`"param1"`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := formatParameters(tt.params)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("formatParameters() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tt.params) == 0 {
|
||||||
|
if got != "" {
|
||||||
|
t.Errorf("formatParameters() = %v, want empty string", got)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, want := range tt.wantContains {
|
||||||
|
if !strings.Contains(got, want) {
|
||||||
|
t.Errorf("formatParameters() result missing expected string: %s\nGot:\n%s", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSkillMarkdown(t *testing.T) {
|
||||||
|
toolsMap := map[string]tools.Tool{
|
||||||
|
"tool1": server.MockTool{
|
||||||
|
Description: "First tool",
|
||||||
|
Params: []parameters.Parameter{
|
||||||
|
parameters.NewStringParameter("p1", "d1"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := generateSkillMarkdown("MySkill", "My Description", toolsMap)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generateSkillMarkdown() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedSubstrings := []string{
|
||||||
|
"name: MySkill",
|
||||||
|
"description: My Description",
|
||||||
|
"## Usage",
|
||||||
|
"All scripts can be executed using Node.js",
|
||||||
|
"**Bash:**",
|
||||||
|
"`node scripts/<script_name>.js '{\"<param_name>\": \"<param_value>\"}'`",
|
||||||
|
"**PowerShell:**",
|
||||||
|
"`node scripts/<script_name>.js '{\"<param_name>\": \"<param_value>\"}'`",
|
||||||
|
"## Scripts",
|
||||||
|
"### tool1",
|
||||||
|
"First tool",
|
||||||
|
"## Parameters",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range expectedSubstrings {
|
||||||
|
if !strings.Contains(got, s) {
|
||||||
|
t.Errorf("generateSkillMarkdown() missing substring %q", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateScriptContent(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
toolName string
|
||||||
|
toolsFileName string
|
||||||
|
wantContains []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic script",
|
||||||
|
toolName: "test-tool",
|
||||||
|
toolsFileName: "",
|
||||||
|
wantContains: []string{
|
||||||
|
`const toolName = "test-tool";`,
|
||||||
|
`const toolsFileName = "";`,
|
||||||
|
`const toolboxArgs = [...configArgs, "invoke", toolName, ...args];`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "script with tools file",
|
||||||
|
toolName: "complex-tool",
|
||||||
|
toolsFileName: "tools.yaml",
|
||||||
|
wantContains: []string{
|
||||||
|
`const toolName = "complex-tool";`,
|
||||||
|
`const toolsFileName = "tools.yaml";`,
|
||||||
|
`configArgs.push("--tools-file", path.join(__dirname, "..", "assets", toolsFileName));`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := generateScriptContent(tt.toolName, tt.toolsFileName)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generateScriptContent() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range tt.wantContains {
|
||||||
|
if !strings.Contains(got, s) {
|
||||||
|
t.Errorf("generateScriptContent() missing substring %q\nGot:\n%s", s, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateToolConfigYAML(t *testing.T) {
|
||||||
|
cfg := server.ServerConfig{
|
||||||
|
ToolConfigs: server.ToolConfigs{
|
||||||
|
"tool1": MockToolConfig{
|
||||||
|
Type: "custom-tool",
|
||||||
|
Source: "src1",
|
||||||
|
Other: "foo",
|
||||||
|
},
|
||||||
|
"toolNoSource": MockToolConfig{
|
||||||
|
Type: "http",
|
||||||
|
},
|
||||||
|
"toolWithParams": MockToolConfig{
|
||||||
|
Type: "custom-tool",
|
||||||
|
Parameters: []parameters.Parameter{
|
||||||
|
parameters.NewStringParameter("param1", "desc1"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SourceConfigs: server.SourceConfigs{
|
||||||
|
"src1": MockSourceConfig{
|
||||||
|
TypeVal: "postgres",
|
||||||
|
ConnVal: "conn1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
toolName string
|
||||||
|
wantContains []string
|
||||||
|
wantErr bool
|
||||||
|
wantNil bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "tool with source",
|
||||||
|
toolName: "tool1",
|
||||||
|
wantContains: []string{
|
||||||
|
"kind: tools",
|
||||||
|
"name: tool1",
|
||||||
|
"type: custom-tool",
|
||||||
|
"source: src1",
|
||||||
|
"other: foo",
|
||||||
|
"---",
|
||||||
|
"kind: sources",
|
||||||
|
"name: src1",
|
||||||
|
"type: postgres",
|
||||||
|
"connection_string: conn1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool without source",
|
||||||
|
toolName: "toolNoSource",
|
||||||
|
wantContains: []string{
|
||||||
|
"kind: tools",
|
||||||
|
"name: toolNoSource",
|
||||||
|
"type: http",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool with parameters",
|
||||||
|
toolName: "toolWithParams",
|
||||||
|
wantContains: []string{
|
||||||
|
"kind: tools",
|
||||||
|
"name: toolWithParams",
|
||||||
|
"type: custom-tool",
|
||||||
|
"parameters:",
|
||||||
|
"- name: param1",
|
||||||
|
"type: string",
|
||||||
|
"description: desc1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existent tool",
|
||||||
|
toolName: "missing-tool",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotBytes, err := generateToolConfigYAML(cfg, tt.toolName)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("generateToolConfigYAML() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantNil {
|
||||||
|
if gotBytes != nil {
|
||||||
|
t.Errorf("generateToolConfigYAML() expected nil, got %s", string(gotBytes))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
got := string(gotBytes)
|
||||||
|
for _, want := range tt.wantContains {
|
||||||
|
if !strings.Contains(got, want) {
|
||||||
|
t.Errorf("generateToolConfigYAML() result missing expected string: %q\nGot:\n%s", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
@@ -233,11 +234,10 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If auth error, return 401 or 403
|
// If auth error, return 401
|
||||||
var clientServerErr *util.ClientServerError
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
if errors.As(err, &clientServerErr) && (clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden) {
|
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
|
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
|
||||||
_ = render.Render(w, r, newErrResponse(err, clientServerErr.Code))
|
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
||||||
@@ -259,50 +259,35 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Determine what error to return to the users.
|
// Determine what error to return to the users.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
var statusCode int
|
||||||
|
|
||||||
if errors.As(err, &tbErr) {
|
// Upstream API auth error propagation
|
||||||
switch tbErr.Category() {
|
switch {
|
||||||
case util.CategoryAgent:
|
case strings.Contains(errStr, "Error 401"):
|
||||||
// Agent Errors -> 200 OK
|
statusCode = http.StatusUnauthorized
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("Tool invocation agent error: %v", err))
|
case strings.Contains(errStr, "Error 403"):
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusOK))
|
statusCode = http.StatusForbidden
|
||||||
return
|
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// Server Errors -> Check the specific code inside
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
statusCode := http.StatusInternalServerError // Default to 500
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code != 0 {
|
|
||||||
statusCode = clientServerErr.Code
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process auth error
|
|
||||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
// Token error, pass through 401/403
|
// Propagate the original 401/403 error.
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("Client credentials lack authorization: %v", err))
|
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
|
||||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// ADC/Config error, return 500
|
// ADC lacking permission or credentials configuration error.
|
||||||
statusCode = http.StatusInternalServerError
|
internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
|
||||||
}
|
s.logger.ErrorContext(ctx, internalErr.Error())
|
||||||
|
_ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
|
||||||
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation server error: %v", err))
|
|
||||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
err = fmt.Errorf("error while invoking tool: %w", err)
|
||||||
// Unknown error -> 500
|
s.logger.DebugContext(ctx, err.Error())
|
||||||
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation unknown error: %v", err))
|
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
resMarshal, err := json.Marshal(res)
|
resMarshal, err := json.Marshal(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/log"
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||||
@@ -41,140 +40,6 @@ var (
|
|||||||
_ prompts.Prompt = MockPrompt{}
|
_ prompts.Prompt = MockPrompt{}
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockTool is used to mock tools in tests
|
|
||||||
type MockTool struct {
|
|
||||||
Name string
|
|
||||||
Description string
|
|
||||||
Params []parameters.Parameter
|
|
||||||
manifest tools.Manifest
|
|
||||||
unauthorized bool
|
|
||||||
requiresClientAuthrorization bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) {
|
|
||||||
mock := []any{t.Name}
|
|
||||||
return mock, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) ToConfig() tools.ToolConfig {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// claims is a map of user info decoded from an auth token
|
|
||||||
func (t MockTool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
|
||||||
return parameters.ParseParams(t.Params, data, claimsMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
|
||||||
return parameters.EmbedParams(ctx, t.Params, paramValues, embeddingModelsMap, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) Manifest() tools.Manifest {
|
|
||||||
pMs := make([]parameters.ParameterManifest, 0, len(t.Params))
|
|
||||||
for _, p := range t.Params {
|
|
||||||
pMs = append(pMs, p.Manifest())
|
|
||||||
}
|
|
||||||
return tools.Manifest{Description: t.Description, Parameters: pMs}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) Authorized(verifiedAuthServices []string) bool {
|
|
||||||
// defaulted to true
|
|
||||||
return !t.unauthorized
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
|
|
||||||
// defaulted to false
|
|
||||||
return t.requiresClientAuthrorization, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) GetParameters() parameters.Parameters {
|
|
||||||
return t.Params
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) McpManifest() tools.McpManifest {
|
|
||||||
properties := make(map[string]parameters.ParameterMcpManifest)
|
|
||||||
required := make([]string, 0)
|
|
||||||
authParams := make(map[string][]string)
|
|
||||||
|
|
||||||
for _, p := range t.Params {
|
|
||||||
name := p.GetName()
|
|
||||||
paramManifest, authParamList := p.McpManifest()
|
|
||||||
properties[name] = paramManifest
|
|
||||||
required = append(required, name)
|
|
||||||
|
|
||||||
if len(authParamList) > 0 {
|
|
||||||
authParams[name] = authParamList
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
toolsSchema := parameters.McpToolsSchema{
|
|
||||||
Type: "object",
|
|
||||||
Properties: properties,
|
|
||||||
Required: required,
|
|
||||||
}
|
|
||||||
|
|
||||||
mcpManifest := tools.McpManifest{
|
|
||||||
Name: t.Name,
|
|
||||||
Description: t.Description,
|
|
||||||
InputSchema: toolsSchema,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(authParams) > 0 {
|
|
||||||
mcpManifest.Metadata = map[string]any{
|
|
||||||
"toolbox/authParams": authParams,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return mcpManifest
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
|
|
||||||
return "Authorization", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockPrompt is used to mock prompts in tests
|
|
||||||
type MockPrompt struct {
|
|
||||||
Name string
|
|
||||||
Description string
|
|
||||||
Args prompts.Arguments
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p MockPrompt) SubstituteParams(vals parameters.ParamValues) (any, error) {
|
|
||||||
return []prompts.Message{
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: fmt.Sprintf("substituted %s", p.Name),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p MockPrompt) ParseArgs(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
|
||||||
var params parameters.Parameters
|
|
||||||
for _, arg := range p.Args {
|
|
||||||
params = append(params, arg.Parameter)
|
|
||||||
}
|
|
||||||
return parameters.ParseParams(params, data, claimsMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p MockPrompt) Manifest() prompts.Manifest {
|
|
||||||
var argManifests []parameters.ParameterManifest
|
|
||||||
for _, arg := range p.Args {
|
|
||||||
argManifests = append(argManifests, arg.Manifest())
|
|
||||||
}
|
|
||||||
return prompts.Manifest{
|
|
||||||
Description: p.Description,
|
|
||||||
Arguments: argManifests,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p MockPrompt) McpManifest() prompts.McpManifest {
|
|
||||||
return prompts.GetMcpManifest(p.Name, p.Description, p.Args)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p MockPrompt) ToConfig() prompts.PromptConfig {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var tool1 = MockTool{
|
var tool1 = MockTool{
|
||||||
Name: "no_params",
|
Name: "no_params",
|
||||||
Params: []parameters.Parameter{},
|
Params: []parameters.Parameter{},
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -443,12 +444,15 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
code := rpcResponse.Error.Code
|
code := rpcResponse.Error.Code
|
||||||
switch code {
|
switch code {
|
||||||
case jsonrpc.INTERNAL_ERROR:
|
case jsonrpc.INTERNAL_ERROR:
|
||||||
// Map Internal RPC Error (-32603) to HTTP 500
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
case jsonrpc.INVALID_REQUEST:
|
case jsonrpc.INVALID_REQUEST:
|
||||||
var clientServerErr *util.ClientServerError
|
errStr := err.Error()
|
||||||
if errors.As(err, &clientServerErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
w.WriteHeader(clientServerErr.Code)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
} else if strings.Contains(errStr, "Error 401") {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
} else if strings.Contains(errStr, "Error 403") {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -123,12 +124,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
errMsg := "missing access token in the 'Authorization' header"
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
|
||||||
errMsg,
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,11 +172,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = util.NewClientServerError(
|
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -202,13 +194,21 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
// Missing authService tokens.
|
||||||
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
// Upstream auth error
|
||||||
|
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||||
|
if clientAuth {
|
||||||
|
// Error with client credentials should pass down to the client
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
// Auth error with ADC should raise internal 500 error
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
|
||||||
if errors.As(err, &tbErr) {
|
|
||||||
switch tbErr.Category() {
|
|
||||||
case util.CategoryAgent:
|
|
||||||
// MCP - Tool execution error
|
|
||||||
// Return SUCCESS but with IsError: true
|
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -218,28 +218,6 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// MCP Spec - Protocol error
|
|
||||||
// Return JSON-RPC ERROR
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
rpcCode = jsonrpc.INVALID_REQUEST
|
|
||||||
} else {
|
|
||||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Unknown error -> 500
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -123,12 +124,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
errMsg := "missing access token in the 'Authorization' header"
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
|
||||||
errMsg,
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,11 +172,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = util.NewClientServerError(
|
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -202,13 +194,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
// Missing authService tokens.
|
||||||
if errors.As(err, &tbErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
switch tbErr.Category() {
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
case util.CategoryAgent:
|
}
|
||||||
// MCP - Tool execution error
|
// Upstream auth error
|
||||||
// Return SUCCESS but with IsError: true
|
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||||
|
if clientAuth {
|
||||||
|
// Error with client credentials should pass down to the client
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
// Auth error with ADC should raise internal 500 error
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -218,29 +217,8 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// MCP Spec - Protocol error
|
|
||||||
// Return JSON-RPC ERROR
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
rpcCode = jsonrpc.INVALID_REQUEST
|
|
||||||
} else {
|
|
||||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Unknown error -> 500
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|
||||||
sliceRes, ok := results.([]any)
|
sliceRes, ok := results.([]any)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -116,12 +117,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
errMsg := "missing access token in the 'Authorization' header"
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
|
||||||
errMsg,
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,11 +165,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = util.NewClientServerError(
|
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -195,13 +187,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
// Missing authService tokens.
|
||||||
if errors.As(err, &tbErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
switch tbErr.Category() {
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
case util.CategoryAgent:
|
}
|
||||||
// MCP - Tool execution error
|
// Upstream auth error
|
||||||
// Return SUCCESS but with IsError: true
|
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||||
|
if clientAuth {
|
||||||
|
// Error with client credentials should pass down to the client
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
// Auth error with ADC should raise internal 500 error
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -211,28 +210,6 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// MCP Spec - Protocol error
|
|
||||||
// Return JSON-RPC ERROR
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
rpcCode = jsonrpc.INVALID_REQUEST
|
|
||||||
} else {
|
|
||||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Unknown error -> 500
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -116,12 +117,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
errMsg := "missing access token in the 'Authorization' header"
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
|
||||||
errMsg,
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,11 +165,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = util.NewClientServerError(
|
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -195,13 +187,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
// Missing authService tokens.
|
||||||
if errors.As(err, &tbErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
switch tbErr.Category() {
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
case util.CategoryAgent:
|
}
|
||||||
// MCP - Tool execution error
|
// Upstream auth error
|
||||||
// Return SUCCESS but with IsError: true
|
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||||
|
if clientAuth {
|
||||||
|
// Error with client credentials should pass down to the client
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
// Auth error with ADC should raise internal 500 error
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -211,28 +210,6 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// MCP Spec - Protocol error
|
|
||||||
// Return JSON-RPC ERROR
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
rpcCode = jsonrpc.INVALID_REQUEST
|
|
||||||
} else {
|
|
||||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Unknown error -> 500
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func TestMcpEndpointWithoutInitialized(t *testing.T) {
|
|||||||
"id": "tools-call-tool4",
|
"id": "tools-call-tool4",
|
||||||
"error": map[string]any{
|
"error": map[string]any{
|
||||||
"code": -32600.0,
|
"code": -32600.0,
|
||||||
"message": "unauthorized Tool call: Please make sure you specify correct auth headers: <nil>",
|
"message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -834,7 +834,7 @@ func TestMcpEndpoint(t *testing.T) {
|
|||||||
"id": "tools-call-tool4",
|
"id": "tools-call-tool4",
|
||||||
"error": map[string]any{
|
"error": map[string]any{
|
||||||
"code": -32600.0,
|
"code": -32600.0,
|
||||||
"message": "unauthorized Tool call: Please make sure you specify correct auth headers: <nil>",
|
"message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
159
internal/server/mocks.go
Normal file
159
internal/server/mocks.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockTool is used to mock tools in tests
|
||||||
|
type MockTool struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Params []parameters.Parameter
|
||||||
|
manifest tools.Manifest
|
||||||
|
unauthorized bool
|
||||||
|
requiresClientAuthrorization bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) {
|
||||||
|
mock := []any{t.Name}
|
||||||
|
return mock, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) ToConfig() tools.ToolConfig {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// claims is a map of user info decoded from an auth token
|
||||||
|
func (t MockTool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
return parameters.ParseParams(t.Params, data, claimsMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
return parameters.EmbedParams(ctx, t.Params, paramValues, embeddingModelsMap, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) Manifest() tools.Manifest {
|
||||||
|
pMs := make([]parameters.ParameterManifest, 0, len(t.Params))
|
||||||
|
for _, p := range t.Params {
|
||||||
|
pMs = append(pMs, p.Manifest())
|
||||||
|
}
|
||||||
|
return tools.Manifest{Description: t.Description, Parameters: pMs}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) Authorized(verifiedAuthServices []string) bool {
|
||||||
|
// defaulted to true
|
||||||
|
return !t.unauthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
|
||||||
|
// defaulted to false
|
||||||
|
return t.requiresClientAuthrorization, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) GetParameters() parameters.Parameters {
|
||||||
|
return t.Params
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) McpManifest() tools.McpManifest {
|
||||||
|
properties := make(map[string]parameters.ParameterMcpManifest)
|
||||||
|
required := make([]string, 0)
|
||||||
|
authParams := make(map[string][]string)
|
||||||
|
|
||||||
|
for _, p := range t.Params {
|
||||||
|
name := p.GetName()
|
||||||
|
paramManifest, authParamList := p.McpManifest()
|
||||||
|
properties[name] = paramManifest
|
||||||
|
required = append(required, name)
|
||||||
|
|
||||||
|
if len(authParamList) > 0 {
|
||||||
|
authParams[name] = authParamList
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
toolsSchema := parameters.McpToolsSchema{
|
||||||
|
Type: "object",
|
||||||
|
Properties: properties,
|
||||||
|
Required: required,
|
||||||
|
}
|
||||||
|
|
||||||
|
mcpManifest := tools.McpManifest{
|
||||||
|
Name: t.Name,
|
||||||
|
Description: t.Description,
|
||||||
|
InputSchema: toolsSchema,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(authParams) > 0 {
|
||||||
|
mcpManifest.Metadata = map[string]any{
|
||||||
|
"toolbox/authParams": authParams,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mcpManifest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
|
||||||
|
return "Authorization", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockPrompt is used to mock prompts in tests
|
||||||
|
type MockPrompt struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Args prompts.Arguments
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MockPrompt) SubstituteParams(vals parameters.ParamValues) (any, error) {
|
||||||
|
return []prompts.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: fmt.Sprintf("substituted %s", p.Name),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MockPrompt) ParseArgs(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
var params parameters.Parameters
|
||||||
|
for _, arg := range p.Args {
|
||||||
|
params = append(params, arg.Parameter)
|
||||||
|
}
|
||||||
|
return parameters.ParseParams(params, data, claimsMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MockPrompt) Manifest() prompts.Manifest {
|
||||||
|
var argManifests []parameters.ParameterManifest
|
||||||
|
for _, arg := range p.Args {
|
||||||
|
argManifests = append(argManifests, arg.Manifest())
|
||||||
|
}
|
||||||
|
return prompts.Manifest{
|
||||||
|
Description: p.Description,
|
||||||
|
Arguments: argManifests,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MockPrompt) McpManifest() prompts.McpManifest {
|
||||||
|
return prompts.GetMcpManifest(p.Name, p.Description, p.Args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MockPrompt) ToConfig() prompts.PromptConfig {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -184,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if source.UseClientAuthorization() {
|
if source.UseClientAuthorization() {
|
||||||
// Use client-side access token
|
// Use client-side access token
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil)
|
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
|
||||||
}
|
}
|
||||||
tokenStr, err = accessToken.ParseBearerToken()
|
tokenStr, err = accessToken.ParseBearerToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ package tools
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -81,7 +80,7 @@ type AccessToken string
|
|||||||
func (token AccessToken) ParseBearerToken() (string, error) {
|
func (token AccessToken) ParseBearerToken() (string, error) {
|
||||||
headerParts := strings.Split(string(token), " ")
|
headerParts := strings.Split(string(token), " ")
|
||||||
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
|
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
|
||||||
return "", util.NewClientServerError("authorization header must be in the format 'Bearer <token>'", http.StatusUnauthorized, nil)
|
return "", fmt.Errorf("authorization header must be in the format 'Bearer <token>': %w", util.ErrUnauthorized)
|
||||||
}
|
}
|
||||||
return headerParts[1], nil
|
return headerParts[1], nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
package util
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
type ErrorCategory string
|
|
||||||
|
|
||||||
const (
|
|
||||||
CategoryAgent ErrorCategory = "AGENT_ERROR"
|
|
||||||
CategoryServer ErrorCategory = "SERVER_ERROR"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ToolboxError is the interface all custom errors must satisfy
|
|
||||||
type ToolboxError interface {
|
|
||||||
error
|
|
||||||
Category() ErrorCategory
|
|
||||||
}
|
|
||||||
|
|
||||||
// Agent Errors return 200 to the sender
|
|
||||||
type AgentError struct {
|
|
||||||
Msg string
|
|
||||||
Cause error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *AgentError) Error() string { return e.Msg }
|
|
||||||
|
|
||||||
func (e *AgentError) Category() ErrorCategory { return CategoryAgent }
|
|
||||||
|
|
||||||
func (e *AgentError) Unwrap() error { return e.Cause }
|
|
||||||
|
|
||||||
func NewAgentError(msg string, cause error) *AgentError {
|
|
||||||
return &AgentError{Msg: msg, Cause: cause}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientServerError returns 4XX/5XX error code
|
|
||||||
type ClientServerError struct {
|
|
||||||
Msg string
|
|
||||||
Code int
|
|
||||||
Cause error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *ClientServerError) Error() string { return fmt.Sprintf("%s: %v", e.Msg, e.Cause) }
|
|
||||||
|
|
||||||
func (e *ClientServerError) Category() ErrorCategory { return CategoryServer }
|
|
||||||
|
|
||||||
func (e *ClientServerError) Unwrap() error { return e.Cause }
|
|
||||||
|
|
||||||
func NewClientServerError(msg string, code int, cause error) *ClientServerError {
|
|
||||||
return &ClientServerError{Msg: msg, Code: code, Cause: cause}
|
|
||||||
}
|
|
||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -119,7 +118,7 @@ func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[st
|
|||||||
}
|
}
|
||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
return nil, util.NewClientServerError("missing or invalid authentication header", http.StatusUnauthorized, nil)
|
return nil, fmt.Errorf("missing or invalid authentication header: %w", util.ErrUnauthorized)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckParamRequired checks if a parameter is required based on the required and default field.
|
// CheckParamRequired checks if a parameter is required based on the required and default field.
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -187,3 +188,5 @@ func InstrumentationFromContext(ctx context.Context) (*telemetry.Instrumentation
|
|||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrUnauthorized = errors.New("unauthorized")
|
||||||
|
|||||||
Reference in New Issue
Block a user