mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-04 04:05:22 -05:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e884f52ea | ||
|
|
8a0f179f15 | ||
|
|
87ae5ae816 | ||
|
|
0c5285c5c8 | ||
|
|
ac544d0878 | ||
|
|
54f9a3d312 | ||
|
|
62d96a662d | ||
|
|
46244458c4 | ||
|
|
b6fa798610 | ||
|
|
bb58baff70 | ||
|
|
32b2c9366d |
@@ -35,7 +35,6 @@ import (
|
|||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||||
"github.com/googleapis/genai-toolbox/internal/cli/invoke"
|
"github.com/googleapis/genai-toolbox/internal/cli/invoke"
|
||||||
"github.com/googleapis/genai-toolbox/internal/cli/skills"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/log"
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||||
@@ -402,8 +401,6 @@ func NewCommand(opts ...Option) *Command {
|
|||||||
|
|
||||||
// Register subcommands for tool invocation
|
// Register subcommands for tool invocation
|
||||||
baseCmd.AddCommand(invoke.NewCommand(cmd))
|
baseCmd.AddCommand(invoke.NewCommand(cmd))
|
||||||
// Register subcommands for skill generation
|
|
||||||
baseCmd.AddCommand(skills.NewCommand(cmd))
|
|
||||||
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,179 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGenerateSkill(t *testing.T) {
|
|
||||||
// Create a temporary directory for tests
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
outputDir := filepath.Join(tmpDir, "skills")
|
|
||||||
|
|
||||||
// Create a tools.yaml file with a sqlite tool
|
|
||||||
toolsFileContent := `
|
|
||||||
sources:
|
|
||||||
my-sqlite:
|
|
||||||
kind: sqlite
|
|
||||||
database: test.db
|
|
||||||
tools:
|
|
||||||
hello-sqlite:
|
|
||||||
kind: sqlite-sql
|
|
||||||
source: my-sqlite
|
|
||||||
description: "hello tool"
|
|
||||||
statement: "SELECT 'hello' as greeting"
|
|
||||||
`
|
|
||||||
|
|
||||||
toolsFilePath := filepath.Join(tmpDir, "tools.yaml")
|
|
||||||
if err := os.WriteFile(toolsFilePath, []byte(toolsFileContent), 0644); err != nil {
|
|
||||||
t.Fatalf("failed to write tools file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
args := []string{
|
|
||||||
"skills-generate",
|
|
||||||
"--tools-file", toolsFilePath,
|
|
||||||
"--output-dir", outputDir,
|
|
||||||
"--name", "hello-sqlite",
|
|
||||||
"--description", "hello tool",
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
_, got, err := invokeCommandWithContext(ctx, args)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("command failed: %v\nOutput: %s", err, got)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify generated directory structure
|
|
||||||
skillPath := filepath.Join(outputDir, "hello-sqlite")
|
|
||||||
if _, err := os.Stat(skillPath); os.IsNotExist(err) {
|
|
||||||
t.Fatalf("skill directory not created: %s", skillPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check SKILL.md
|
|
||||||
skillMarkdown := filepath.Join(skillPath, "SKILL.md")
|
|
||||||
content, err := os.ReadFile(skillMarkdown)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to read SKILL.md: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedFrontmatter := `---
|
|
||||||
name: hello-sqlite
|
|
||||||
description: hello tool
|
|
||||||
---`
|
|
||||||
if !strings.HasPrefix(string(content), expectedFrontmatter) {
|
|
||||||
t.Errorf("SKILL.md does not have expected frontmatter format.\nExpected prefix:\n%s\nGot:\n%s", expectedFrontmatter, string(content))
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(string(content), "## Usage") {
|
|
||||||
t.Errorf("SKILL.md does not contain '## Usage' section")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(string(content), "## Scripts") {
|
|
||||||
t.Errorf("SKILL.md does not contain '## Scripts' section")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(string(content), "### hello-sqlite") {
|
|
||||||
t.Errorf("SKILL.md does not contain '### hello-sqlite' tool header")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check script file
|
|
||||||
scriptFilename := "hello-sqlite.js"
|
|
||||||
scriptPath := filepath.Join(skillPath, "scripts", scriptFilename)
|
|
||||||
if _, err := os.Stat(scriptPath); os.IsNotExist(err) {
|
|
||||||
t.Fatalf("script file not created: %s", scriptPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
scriptContent, err := os.ReadFile(scriptPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to read script file: %v", err)
|
|
||||||
}
|
|
||||||
if !strings.Contains(string(scriptContent), "hello-sqlite") {
|
|
||||||
t.Errorf("script file does not contain expected tool name")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check assets
|
|
||||||
assetPath := filepath.Join(skillPath, "assets", "hello-sqlite.yaml")
|
|
||||||
if _, err := os.Stat(assetPath); os.IsNotExist(err) {
|
|
||||||
t.Fatalf("asset file not created: %s", assetPath)
|
|
||||||
}
|
|
||||||
assetContent, err := os.ReadFile(assetPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to read asset file: %v", err)
|
|
||||||
}
|
|
||||||
if !strings.Contains(string(assetContent), "hello-sqlite") {
|
|
||||||
t.Errorf("asset file does not contain expected tool name")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateSkill_NoConfig(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
outputDir := filepath.Join(tmpDir, "skills")
|
|
||||||
|
|
||||||
args := []string{
|
|
||||||
"skills-generate",
|
|
||||||
"--output-dir", outputDir,
|
|
||||||
"--name", "test",
|
|
||||||
"--description", "test",
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, err := invokeCommandWithContext(context.Background(), args)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected command to fail when no configuration is provided and tools.yaml is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should not have created the directory if no config was processed
|
|
||||||
if _, err := os.Stat(outputDir); !os.IsNotExist(err) {
|
|
||||||
t.Errorf("output directory should not have been created")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateSkill_MissingArguments(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
toolsFilePath := filepath.Join(tmpDir, "tools.yaml")
|
|
||||||
if err := os.WriteFile(toolsFilePath, []byte("tools: {}"), 0644); err != nil {
|
|
||||||
t.Fatalf("failed to write tools file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args []string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "missing name",
|
|
||||||
args: []string{"skills-generate", "--tools-file", toolsFilePath, "--description", "test"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "missing description",
|
|
||||||
args: []string{"skills-generate", "--tools-file", toolsFilePath, "--name", "test"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
_, got, err := invokeCommandWithContext(context.Background(), tt.args)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected command to fail due to missing arguments, but it succeeded\nOutput: %s", got)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -53,7 +53,7 @@ export async function main() {
|
|||||||
|
|
||||||
for (const query of queries) {
|
for (const query of queries) {
|
||||||
conversationHistory.push({ role: "user", content: [{ text: query }] });
|
conversationHistory.push({ role: "user", content: [{ text: query }] });
|
||||||
let response = await ai.generate({
|
const response = await ai.generate({
|
||||||
messages: conversationHistory,
|
messages: conversationHistory,
|
||||||
tools: tools,
|
tools: tools,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,112 +0,0 @@
|
|||||||
---
|
|
||||||
title: "Generate Agent Skills"
|
|
||||||
type: docs
|
|
||||||
weight: 10
|
|
||||||
description: >
|
|
||||||
How to generate agent skills from a toolset.
|
|
||||||
---
|
|
||||||
|
|
||||||
The `skills-generate` command allows you to convert a **toolset** into an **Agent Skill**. A toolset is a collection of tools, and the generated skill will contain metadata and execution scripts for all tools within that toolset, complying with the [Agent Skill specification](https://agentskills.io/specification).
|
|
||||||
|
|
||||||
## Before you begin
|
|
||||||
|
|
||||||
1. Make sure you have the `toolbox` executable in your PATH.
|
|
||||||
2. Make sure you have [Node.js](https://nodejs.org/) installed on your system.
|
|
||||||
|
|
||||||
## Generating a Skill from a Toolset
|
|
||||||
|
|
||||||
A skill package consists of a `SKILL.md` file (with required YAML frontmatter) and a set of Node.js scripts. Each tool defined in your toolset maps to a corresponding script in the generated Node.js scripts (`.js`) that work across different platforms (Linux, macOS, Windows).
|
|
||||||
|
|
||||||
|
|
||||||
### Command Usage
|
|
||||||
|
|
||||||
The basic syntax for the command is:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
toolbox <tool-source> skills-generate \
|
|
||||||
--name <skill-name> \
|
|
||||||
--toolset <toolset-name> \
|
|
||||||
--description <description> \
|
|
||||||
--output-dir <output-directory>
|
|
||||||
```
|
|
||||||
|
|
||||||
- `<tool-source>`: Can be `--tools-file`, `--tools-files`, `--tools-folder`, and `--prebuilt`. See the [CLI Reference](../reference/cli.md) for details.
|
|
||||||
- `--name`: Name of the generated skill.
|
|
||||||
- `--description`: Description of the generated skill.
|
|
||||||
- `--toolset`: (Optional) Name of the toolset to convert into a skill. If not provided, all tools will be included.
|
|
||||||
- `--output-dir`: (Optional) Directory to output generated skills (default: "skills").
|
|
||||||
|
|
||||||
{{< notice note >}}
|
|
||||||
**Note:** The `<skill-name>` must follow the Agent Skill [naming convention](https://agentskills.io/specification): it must contain only lowercase alphanumeric characters and hyphens, cannot start or end with a hyphen, and cannot contain consecutive hyphens (e.g., `my-skill`, `data-processing`).
|
|
||||||
{{< /notice >}}
|
|
||||||
|
|
||||||
### Example: Custom Tools File
|
|
||||||
|
|
||||||
1. Create a `tools.yaml` file with a toolset and some tools:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
tools:
|
|
||||||
tool_a:
|
|
||||||
description: "First tool"
|
|
||||||
run:
|
|
||||||
command: "echo 'Tool A'"
|
|
||||||
tool_b:
|
|
||||||
description: "Second tool"
|
|
||||||
run:
|
|
||||||
command: "echo 'Tool B'"
|
|
||||||
toolsets:
|
|
||||||
my_toolset:
|
|
||||||
tools:
|
|
||||||
- tool_a
|
|
||||||
- tool_b
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Generate the skill:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
toolbox --tools-file tools.yaml skills-generate \
|
|
||||||
--name "my-skill" \
|
|
||||||
--toolset "my_toolset" \
|
|
||||||
--description "A skill containing multiple tools" \
|
|
||||||
--output-dir "generated-skills"
|
|
||||||
```
|
|
||||||
|
|
||||||
3. The generated skill directory structure:
|
|
||||||
|
|
||||||
```text
|
|
||||||
generated-skills/
|
|
||||||
└── my-skill/
|
|
||||||
├── SKILL.md
|
|
||||||
├── assets/
|
|
||||||
│ ├── tool_a.yaml
|
|
||||||
│ └── tool_b.yaml
|
|
||||||
└── scripts/
|
|
||||||
├── tool_a.js
|
|
||||||
└── tool_b.js
|
|
||||||
```
|
|
||||||
|
|
||||||
In this example, the skill contains two Node.js scripts (`tool_a.js` and `tool_b.js`), each mapping to a tool in the original toolset.
|
|
||||||
|
|
||||||
### Example: Prebuilt Configuration
|
|
||||||
|
|
||||||
You can also generate skills from prebuilt toolsets:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
toolbox --prebuilt alloydb-postgres-admin skills-generate \
|
|
||||||
--name "alloydb-postgres-admin" \
|
|
||||||
--description "skill for performing administrative operations on alloydb"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Installing the Generated Skill in Gemini CLI
|
|
||||||
|
|
||||||
Once you have generated a skill, you can install it into the Gemini CLI using the `gemini skills install` command.
|
|
||||||
|
|
||||||
### Installation Command
|
|
||||||
|
|
||||||
Provide the path to the directory containing the generated skill:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
gemini skills install /path/to/generated-skills/my-skill
|
|
||||||
```
|
|
||||||
|
|
||||||
Alternatively, use ~/.gemini/skills as the `--output-dir` to generate the skill straight to the Gemini CLI.
|
|
||||||
@@ -13,22 +13,21 @@ The `invoke` command allows you to invoke tools defined in your configuration di
|
|||||||
|
|
||||||
{{< notice tip >}}
|
{{< notice tip >}}
|
||||||
**Keep configurations minimal:** The `invoke` command initializes *all* resources (sources, tools, etc.) defined in your configuration files during execution. To ensure fast response times, consider using a minimal configuration file containing only the tools you need for the specific invocation.
|
**Keep configurations minimal:** The `invoke` command initializes *all* resources (sources, tools, etc.) defined in your configuration files during execution. To ensure fast response times, consider using a minimal configuration file containing only the tools you need for the specific invocation.
|
||||||
{{< /notice >}}
|
{{< notice tip >}}
|
||||||
|
|
||||||
## Before you begin
|
## Prerequisites
|
||||||
|
|
||||||
1. Make sure you have the `toolbox` binary installed or built.
|
- You have the `toolbox` binary installed or built.
|
||||||
2. Make sure you have a valid tool configuration file (e.g., `tools.yaml`).
|
- You have a valid tool configuration file (e.g., `tools.yaml`).
|
||||||
|
|
||||||
### Command Usage
|
## Basic Usage
|
||||||
|
|
||||||
The basic syntax for the command is:
|
The basic syntax for the command is:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
toolbox <tool-source> invoke <tool-name> [params]
|
toolbox [--tools-file <path> | --prebuilt <name>] invoke <tool-name> [params]
|
||||||
```
|
```
|
||||||
|
|
||||||
- `<tool-source>`: Can be `--tools-file`, `--tools-files`, `--tools-folder`, and `--prebuilt`. See the [CLI Reference](../reference/cli.md) for details.
|
|
||||||
- `<tool-name>`: The name of the tool you want to call. This must match the name defined in your `tools.yaml`.
|
- `<tool-name>`: The name of the tool you want to call. This must match the name defined in your `tools.yaml`.
|
||||||
- `[params]`: (Optional) A JSON string representing the arguments for the tool.
|
- `[params]`: (Optional) A JSON string representing the arguments for the tool.
|
||||||
|
|
||||||
|
|||||||
@@ -32,8 +32,7 @@ description: >
|
|||||||
|
|
||||||
## Sub Commands
|
## Sub Commands
|
||||||
|
|
||||||
<details>
|
### `invoke`
|
||||||
<summary><code>invoke</code></summary>
|
|
||||||
|
|
||||||
Executes a tool directly with the provided parameters. This is useful for testing tool configurations and parameters without needing a full client setup.
|
Executes a tool directly with the provided parameters. This is useful for testing tool configurations and parameters without needing a full client setup.
|
||||||
|
|
||||||
@@ -43,36 +42,8 @@ Executes a tool directly with the provided parameters. This is useful for testin
|
|||||||
toolbox invoke <tool-name> [params]
|
toolbox invoke <tool-name> [params]
|
||||||
```
|
```
|
||||||
|
|
||||||
**Arguments:**
|
- `<tool-name>`: The name of the tool to execute (as defined in your configuration).
|
||||||
|
- `[params]`: (Optional) A JSON string containing the parameters for the tool.
|
||||||
- `tool-name`: The name of the tool to execute (as defined in your configuration).
|
|
||||||
- `params`: (Optional) A JSON string containing the parameters for the tool.
|
|
||||||
|
|
||||||
For more detailed instructions, see [Invoke Tools via CLI](../how-to/invoke_tool.md).
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><code>skills-generate</code></summary>
|
|
||||||
|
|
||||||
Generates a skill package from a specified toolset. Each tool in the toolset will have a corresponding Node.js execution script in the generated skill.
|
|
||||||
|
|
||||||
**Syntax:**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
toolbox skills-generate --name <name> --description <description> --toolset <toolset> --output-dir <output>
|
|
||||||
```
|
|
||||||
|
|
||||||
**Flags:**
|
|
||||||
|
|
||||||
- `--name`: Name of the generated skill.
|
|
||||||
- `--description`: Description of the generated skill.
|
|
||||||
- `--toolset`: (Optional) Name of the toolset to convert into a skill. If not provided, all tools will be included.
|
|
||||||
- `--output-dir`: (Optional) Directory to output generated skills (default: "skills").
|
|
||||||
|
|
||||||
For more detailed instructions, see [Generate Agent Skills](../how-to/generate_skill.md).
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
|||||||
@@ -1,237 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package skills
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
_ "embed"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/log"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RootCommand defines the interface for required by skills-generate subcommand.
|
|
||||||
// This allows subcommands to access shared resources and functionality without
|
|
||||||
// direct coupling to the root command's implementation.
|
|
||||||
type RootCommand interface {
|
|
||||||
// Config returns a copy of the current server configuration.
|
|
||||||
Config() server.ServerConfig
|
|
||||||
|
|
||||||
// LoadConfig loads and merges the configuration from files, folders, and prebuilts.
|
|
||||||
LoadConfig(ctx context.Context) error
|
|
||||||
|
|
||||||
// Setup initializes the runtime environment, including logging and telemetry.
|
|
||||||
// It returns the updated context and a shutdown function to be called when finished.
|
|
||||||
Setup(ctx context.Context) (context.Context, func(context.Context) error, error)
|
|
||||||
|
|
||||||
// Logger returns the logger instance.
|
|
||||||
Logger() log.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// Command is the command for generating skills.
|
|
||||||
type Command struct {
|
|
||||||
*cobra.Command
|
|
||||||
rootCmd RootCommand
|
|
||||||
name string
|
|
||||||
description string
|
|
||||||
toolset string
|
|
||||||
outputDir string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCommand creates a new Command.
|
|
||||||
func NewCommand(rootCmd RootCommand) *cobra.Command {
|
|
||||||
cmd := &Command{
|
|
||||||
rootCmd: rootCmd,
|
|
||||||
}
|
|
||||||
cmd.Command = &cobra.Command{
|
|
||||||
Use: "skills-generate",
|
|
||||||
Short: "Generate skills from tool configurations",
|
|
||||||
RunE: func(c *cobra.Command, args []string) error {
|
|
||||||
return cmd.run(c)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.Flags().StringVar(&cmd.name, "name", "", "Name of the generated skill.")
|
|
||||||
cmd.Flags().StringVar(&cmd.description, "description", "", "Description of the generated skill")
|
|
||||||
cmd.Flags().StringVar(&cmd.toolset, "toolset", "", "Name of the toolset to convert into a skill. If not provided, all tools will be included.")
|
|
||||||
cmd.Flags().StringVar(&cmd.outputDir, "output-dir", "skills", "Directory to output generated skills")
|
|
||||||
|
|
||||||
_ = cmd.MarkFlagRequired("name")
|
|
||||||
_ = cmd.MarkFlagRequired("description")
|
|
||||||
return cmd.Command
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Command) run(cmd *cobra.Command) error {
|
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
ctx, shutdown, err := c.rootCmd.Setup(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = shutdown(ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
logger := c.rootCmd.Logger()
|
|
||||||
|
|
||||||
// Load and merge tool configurations
|
|
||||||
if err := c.rootCmd.LoadConfig(ctx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.MkdirAll(c.outputDir, 0755); err != nil {
|
|
||||||
errMsg := fmt.Errorf("error creating output directory: %w", err)
|
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
|
||||||
return errMsg
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.InfoContext(ctx, fmt.Sprintf("Generating skill '%s'...", c.name))
|
|
||||||
|
|
||||||
// Initialize toolbox and collect tools
|
|
||||||
allTools, err := c.collectTools(ctx)
|
|
||||||
if err != nil {
|
|
||||||
errMsg := fmt.Errorf("error collecting tools: %w", err)
|
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
|
||||||
return errMsg
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(allTools) == 0 {
|
|
||||||
logger.InfoContext(ctx, "No tools found to generate.")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate the combined skill directory
|
|
||||||
skillPath := filepath.Join(c.outputDir, c.name)
|
|
||||||
if err := os.MkdirAll(skillPath, 0755); err != nil {
|
|
||||||
errMsg := fmt.Errorf("error creating skill directory: %w", err)
|
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
|
||||||
return errMsg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate assets directory
|
|
||||||
assetsPath := filepath.Join(skillPath, "assets")
|
|
||||||
if err := os.MkdirAll(assetsPath, 0755); err != nil {
|
|
||||||
errMsg := fmt.Errorf("error creating assets dir: %w", err)
|
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
|
||||||
return errMsg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate scripts directory
|
|
||||||
scriptsPath := filepath.Join(skillPath, "scripts")
|
|
||||||
if err := os.MkdirAll(scriptsPath, 0755); err != nil {
|
|
||||||
errMsg := fmt.Errorf("error creating scripts dir: %w", err)
|
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
|
||||||
return errMsg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over keys to ensure deterministic order
|
|
||||||
var toolNames []string
|
|
||||||
for name := range allTools {
|
|
||||||
toolNames = append(toolNames, name)
|
|
||||||
}
|
|
||||||
sort.Strings(toolNames)
|
|
||||||
|
|
||||||
for _, toolName := range toolNames {
|
|
||||||
// Generate YAML config in asset directory
|
|
||||||
minimizedContent, err := generateToolConfigYAML(c.rootCmd.Config(), toolName)
|
|
||||||
if err != nil {
|
|
||||||
errMsg := fmt.Errorf("error generating filtered config for %s: %w", toolName, err)
|
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
|
||||||
return errMsg
|
|
||||||
}
|
|
||||||
|
|
||||||
specificToolsFileName := fmt.Sprintf("%s.yaml", toolName)
|
|
||||||
if minimizedContent != nil {
|
|
||||||
destPath := filepath.Join(assetsPath, specificToolsFileName)
|
|
||||||
if err := os.WriteFile(destPath, minimizedContent, 0644); err != nil {
|
|
||||||
errMsg := fmt.Errorf("error writing filtered config for %s: %w", toolName, err)
|
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
|
||||||
return errMsg
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate wrapper script in scripts directory
|
|
||||||
scriptContent, err := generateScriptContent(toolName, specificToolsFileName)
|
|
||||||
if err != nil {
|
|
||||||
errMsg := fmt.Errorf("error generating script content for %s: %w", toolName, err)
|
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
|
||||||
return errMsg
|
|
||||||
}
|
|
||||||
|
|
||||||
scriptFilename := filepath.Join(scriptsPath, fmt.Sprintf("%s.js", toolName))
|
|
||||||
if err := os.WriteFile(scriptFilename, []byte(scriptContent), 0755); err != nil {
|
|
||||||
errMsg := fmt.Errorf("error writing script %s: %w", scriptFilename, err)
|
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
|
||||||
return errMsg
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate SKILL.md
|
|
||||||
skillContent, err := generateSkillMarkdown(c.name, c.description, allTools)
|
|
||||||
if err != nil {
|
|
||||||
errMsg := fmt.Errorf("error generating SKILL.md content: %w", err)
|
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
|
||||||
return errMsg
|
|
||||||
}
|
|
||||||
skillMdPath := filepath.Join(skillPath, "SKILL.md")
|
|
||||||
if err := os.WriteFile(skillMdPath, []byte(skillContent), 0644); err != nil {
|
|
||||||
errMsg := fmt.Errorf("error writing SKILL.md: %w", err)
|
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
|
||||||
return errMsg
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.InfoContext(ctx, fmt.Sprintf("Successfully generated skill '%s' with %d tools.", c.name, len(allTools)))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Command) collectTools(ctx context.Context) (map[string]tools.Tool, error) {
|
|
||||||
// Initialize Resources
|
|
||||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, c.rootCmd.Config())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to initialize resources: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
resourceMgr := resources.NewResourceManager(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
|
||||||
|
|
||||||
result := make(map[string]tools.Tool)
|
|
||||||
|
|
||||||
if c.toolset == "" {
|
|
||||||
return toolsMap, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ts, ok := resourceMgr.GetToolset(c.toolset)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("toolset %q not found", c.toolset)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range ts.Tools {
|
|
||||||
if t != nil {
|
|
||||||
tool := *t
|
|
||||||
result[tool.McpManifest().Name] = tool
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
@@ -1,297 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package skills
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"text/template"
|
|
||||||
|
|
||||||
"github.com/goccy/go-yaml"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
|
||||||
)
|
|
||||||
|
|
||||||
const skillTemplate = `---
|
|
||||||
name: {{.SkillName}}
|
|
||||||
description: {{.SkillDescription}}
|
|
||||||
---
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
All scripts can be executed using Node.js. Replace ` + "`" + `<param_name>` + "`" + ` and ` + "`" + `<param_value>` + "`" + ` with actual values.
|
|
||||||
|
|
||||||
**Bash:**
|
|
||||||
` + "`" + `node scripts/<script_name>.js '{"<param_name>": "<param_value>"}'` + "`" + `
|
|
||||||
|
|
||||||
**PowerShell:**
|
|
||||||
` + "`" + `node scripts/<script_name>.js '{\"<param_name>\": \"<param_value>\"}'` + "`" + `
|
|
||||||
|
|
||||||
## Scripts
|
|
||||||
|
|
||||||
{{range .Tools}}
|
|
||||||
### {{.Name}}
|
|
||||||
|
|
||||||
{{.Description}}
|
|
||||||
|
|
||||||
{{.ParametersSchema}}
|
|
||||||
|
|
||||||
---
|
|
||||||
{{end}}
|
|
||||||
`
|
|
||||||
|
|
||||||
type toolTemplateData struct {
|
|
||||||
Name string
|
|
||||||
Description string
|
|
||||||
ParametersSchema string
|
|
||||||
}
|
|
||||||
|
|
||||||
type skillTemplateData struct {
|
|
||||||
SkillName string
|
|
||||||
SkillDescription string
|
|
||||||
Tools []toolTemplateData
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateSkillMarkdown generates the content of the SKILL.md file.
|
|
||||||
// It includes usage instructions and a reference section for each tool in the skill,
|
|
||||||
// detailing its description and parameters.
|
|
||||||
func generateSkillMarkdown(skillName, skillDescription string, toolsMap map[string]tools.Tool) (string, error) {
|
|
||||||
var toolsData []toolTemplateData
|
|
||||||
|
|
||||||
// Order tools based on name
|
|
||||||
var toolNames []string
|
|
||||||
for name := range toolsMap {
|
|
||||||
toolNames = append(toolNames, name)
|
|
||||||
}
|
|
||||||
sort.Strings(toolNames)
|
|
||||||
|
|
||||||
for _, name := range toolNames {
|
|
||||||
tool := toolsMap[name]
|
|
||||||
manifest := tool.Manifest()
|
|
||||||
|
|
||||||
parametersSchema, err := formatParameters(manifest.Parameters)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
toolsData = append(toolsData, toolTemplateData{
|
|
||||||
Name: name,
|
|
||||||
Description: manifest.Description,
|
|
||||||
ParametersSchema: parametersSchema,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
data := skillTemplateData{
|
|
||||||
SkillName: skillName,
|
|
||||||
SkillDescription: skillDescription,
|
|
||||||
Tools: toolsData,
|
|
||||||
}
|
|
||||||
|
|
||||||
tmpl, err := template.New("markdown").Parse(skillTemplate)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error parsing markdown template: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf strings.Builder
|
|
||||||
if err := tmpl.Execute(&buf, data); err != nil {
|
|
||||||
return "", fmt.Errorf("error executing markdown template: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const nodeScriptTemplate = `#!/usr/bin/env node
|
|
||||||
|
|
||||||
const { spawn, execSync } = require('child_process');
|
|
||||||
const path = require('path');
|
|
||||||
const fs = require('fs');
|
|
||||||
|
|
||||||
const toolName = "{{.Name}}";
|
|
||||||
const toolsFileName = "{{.ToolsFileName}}";
|
|
||||||
|
|
||||||
function getToolboxPath() {
|
|
||||||
try {
|
|
||||||
const checkCommand = process.platform === 'win32' ? 'where toolbox' : 'which toolbox';
|
|
||||||
const globalPath = execSync(checkCommand, { stdio: 'pipe', encoding: 'utf-8' }).trim();
|
|
||||||
if (globalPath) {
|
|
||||||
return globalPath.split('\n')[0].trim();
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
// Ignore error;
|
|
||||||
}
|
|
||||||
const localPath = path.resolve(__dirname, '../../../toolbox');
|
|
||||||
if (fs.existsSync(localPath)) {
|
|
||||||
return localPath;
|
|
||||||
}
|
|
||||||
throw new Error("Toolbox binary not found");
|
|
||||||
}
|
|
||||||
|
|
||||||
let toolboxBinary;
|
|
||||||
try {
|
|
||||||
toolboxBinary = getToolboxPath();
|
|
||||||
} catch (err) {
|
|
||||||
console.error("Error:", err.message);
|
|
||||||
process.exit(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
let configArgs = [];
|
|
||||||
if (toolsFileName) {
|
|
||||||
configArgs.push("--tools-file", path.join(__dirname, "..", "assets", toolsFileName));
|
|
||||||
}
|
|
||||||
|
|
||||||
const args = process.argv.slice(2);
|
|
||||||
const toolboxArgs = [...configArgs, "invoke", toolName, ...args];
|
|
||||||
|
|
||||||
const child = spawn(toolboxBinary, toolboxArgs, { stdio: 'inherit' });
|
|
||||||
|
|
||||||
child.on('close', (code) => {
|
|
||||||
process.exit(code);
|
|
||||||
});
|
|
||||||
|
|
||||||
child.on('error', (err) => {
|
|
||||||
console.error("Error executing toolbox:", err);
|
|
||||||
process.exit(1);
|
|
||||||
});
|
|
||||||
`
|
|
||||||
|
|
||||||
type scriptData struct {
|
|
||||||
Name string
|
|
||||||
ToolsFileName string
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateScriptContent creates the content for a Node.js wrapper script.
|
|
||||||
// This script invokes the toolbox CLI with the appropriate configuration
|
|
||||||
// (using a generated tools file) and arguments to execute the specific tool.
|
|
||||||
func generateScriptContent(name string, toolsFileName string) (string, error) {
|
|
||||||
data := scriptData{
|
|
||||||
Name: name,
|
|
||||||
ToolsFileName: toolsFileName,
|
|
||||||
}
|
|
||||||
|
|
||||||
tmpl, err := template.New("script").Parse(nodeScriptTemplate)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error parsing script template: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf strings.Builder
|
|
||||||
if err := tmpl.Execute(&buf, data); err != nil {
|
|
||||||
return "", fmt.Errorf("error executing script template: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatParameters converts a list of parameter manifests into a formatted JSON schema string.
|
|
||||||
// This schema is used in the skill documentation to describe the input parameters for a tool.
|
|
||||||
func formatParameters(params []parameters.ParameterManifest) (string, error) {
|
|
||||||
if len(params) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
properties := make(map[string]interface{})
|
|
||||||
var required []string
|
|
||||||
|
|
||||||
for _, p := range params {
|
|
||||||
paramMap := map[string]interface{}{
|
|
||||||
"type": p.Type,
|
|
||||||
"description": p.Description,
|
|
||||||
}
|
|
||||||
if p.Default != nil {
|
|
||||||
paramMap["default"] = p.Default
|
|
||||||
}
|
|
||||||
properties[p.Name] = paramMap
|
|
||||||
if p.Required {
|
|
||||||
required = append(required, p.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
schema := map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": properties,
|
|
||||||
}
|
|
||||||
if len(required) > 0 {
|
|
||||||
schema["required"] = required
|
|
||||||
}
|
|
||||||
|
|
||||||
schemaJSON, err := json.MarshalIndent(schema, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error generating parameters schema: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("#### Parameters\n\n```json\n%s\n```", string(schemaJSON)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateToolConfigYAML generates the YAML configuration for a single tool and its dependency (source).
|
|
||||||
// It extracts the relevant tool and source configurations from the server config and formats them
|
|
||||||
// into a YAML document suitable for inclusion in the skill's assets.
|
|
||||||
func generateToolConfigYAML(cfg server.ServerConfig, toolName string) ([]byte, error) {
|
|
||||||
toolCfg, ok := cfg.ToolConfigs[toolName]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("error finding tool config: %s", toolName)
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
encoder := yaml.NewEncoder(&buf)
|
|
||||||
|
|
||||||
// Process Tool Config
|
|
||||||
toolWrapper := struct {
|
|
||||||
Kind string `yaml:"kind"`
|
|
||||||
Name string `yaml:"name"`
|
|
||||||
Config tools.ToolConfig `yaml:",inline"`
|
|
||||||
}{
|
|
||||||
Kind: "tools",
|
|
||||||
Name: toolName,
|
|
||||||
Config: toolCfg,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := encoder.Encode(toolWrapper); err != nil {
|
|
||||||
return nil, fmt.Errorf("error encoding tool config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process Source Config
|
|
||||||
var toolMap map[string]interface{}
|
|
||||||
b, err := yaml.Marshal(toolCfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error marshaling tool config: %w", err)
|
|
||||||
}
|
|
||||||
if err := yaml.Unmarshal(b, &toolMap); err != nil {
|
|
||||||
return nil, fmt.Errorf("error unmarshaling tool config map: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sourceName, ok := toolMap["source"].(string); ok {
|
|
||||||
if sourceCfg, ok := cfg.SourceConfigs[sourceName]; ok {
|
|
||||||
sourceWrapper := struct {
|
|
||||||
Kind string `yaml:"kind"`
|
|
||||||
Name string `yaml:"name"`
|
|
||||||
Config sources.SourceConfig `yaml:",inline"`
|
|
||||||
}{
|
|
||||||
Kind: "sources",
|
|
||||||
Name: sourceName,
|
|
||||||
Config: sourceCfg,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := encoder.Encode(sourceWrapper); err != nil {
|
|
||||||
return nil, fmt.Errorf("error encoding source config: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.Bytes(), nil
|
|
||||||
}
|
|
||||||
@@ -1,338 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package skills
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
|
||||||
"go.opentelemetry.io/otel/trace"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockToolConfig struct {
|
|
||||||
Type string `yaml:"type"`
|
|
||||||
Source string `yaml:"source"`
|
|
||||||
Other string `yaml:"other"`
|
|
||||||
Parameters parameters.Parameters `yaml:"parameters"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m MockToolConfig) ToolConfigType() string {
|
|
||||||
return m.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m MockToolConfig) Initialize(map[string]sources.Source) (tools.Tool, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockSourceConfig struct {
|
|
||||||
TypeVal string
|
|
||||||
ConnVal string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m MockSourceConfig) SourceConfigType() string {
|
|
||||||
return m.TypeVal
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m MockSourceConfig) Initialize(context.Context, trace.Tracer) (sources.Source, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m MockSourceConfig) MarshalYAML() (interface{}, error) {
|
|
||||||
return map[string]interface{}{
|
|
||||||
"connection_string": m.ConnVal,
|
|
||||||
"type": m.TypeVal,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFormatParameters(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
params []parameters.ParameterManifest
|
|
||||||
wantContains []string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty parameters",
|
|
||||||
params: []parameters.ParameterManifest{},
|
|
||||||
wantContains: []string{""},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single required string parameter",
|
|
||||||
params: []parameters.ParameterManifest{
|
|
||||||
{
|
|
||||||
Name: "param1",
|
|
||||||
Description: "A test parameter",
|
|
||||||
Type: "string",
|
|
||||||
Required: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantContains: []string{
|
|
||||||
"## Parameters",
|
|
||||||
"```json",
|
|
||||||
`"type": "object"`,
|
|
||||||
`"properties": {`,
|
|
||||||
`"param1": {`,
|
|
||||||
`"type": "string"`,
|
|
||||||
`"description": "A test parameter"`,
|
|
||||||
`"required": [`,
|
|
||||||
`"param1"`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed parameters with defaults",
|
|
||||||
params: []parameters.ParameterManifest{
|
|
||||||
{
|
|
||||||
Name: "param1",
|
|
||||||
Description: "Param 1",
|
|
||||||
Type: "string",
|
|
||||||
Required: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "param2",
|
|
||||||
Description: "Param 2",
|
|
||||||
Type: "integer",
|
|
||||||
Default: 42,
|
|
||||||
Required: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantContains: []string{
|
|
||||||
`"param1": {`,
|
|
||||||
`"param2": {`,
|
|
||||||
`"default": 42`,
|
|
||||||
`"required": [`,
|
|
||||||
`"param1"`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := formatParameters(tt.params)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("formatParameters() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tt.wantErr {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(tt.params) == 0 {
|
|
||||||
if got != "" {
|
|
||||||
t.Errorf("formatParameters() = %v, want empty string", got)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, want := range tt.wantContains {
|
|
||||||
if !strings.Contains(got, want) {
|
|
||||||
t.Errorf("formatParameters() result missing expected string: %s\nGot:\n%s", want, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateSkillMarkdown(t *testing.T) {
|
|
||||||
toolsMap := map[string]tools.Tool{
|
|
||||||
"tool1": server.MockTool{
|
|
||||||
Description: "First tool",
|
|
||||||
Params: []parameters.Parameter{
|
|
||||||
parameters.NewStringParameter("p1", "d1"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
got, err := generateSkillMarkdown("MySkill", "My Description", toolsMap)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("generateSkillMarkdown() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedSubstrings := []string{
|
|
||||||
"name: MySkill",
|
|
||||||
"description: My Description",
|
|
||||||
"## Usage",
|
|
||||||
"All scripts can be executed using Node.js",
|
|
||||||
"**Bash:**",
|
|
||||||
"`node scripts/<script_name>.js '{\"<param_name>\": \"<param_value>\"}'`",
|
|
||||||
"**PowerShell:**",
|
|
||||||
"`node scripts/<script_name>.js '{\"<param_name>\": \"<param_value>\"}'`",
|
|
||||||
"## Scripts",
|
|
||||||
"### tool1",
|
|
||||||
"First tool",
|
|
||||||
"## Parameters",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, s := range expectedSubstrings {
|
|
||||||
if !strings.Contains(got, s) {
|
|
||||||
t.Errorf("generateSkillMarkdown() missing substring %q", s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateScriptContent(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
toolName string
|
|
||||||
toolsFileName string
|
|
||||||
wantContains []string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "basic script",
|
|
||||||
toolName: "test-tool",
|
|
||||||
toolsFileName: "",
|
|
||||||
wantContains: []string{
|
|
||||||
`const toolName = "test-tool";`,
|
|
||||||
`const toolsFileName = "";`,
|
|
||||||
`const toolboxArgs = [...configArgs, "invoke", toolName, ...args];`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "script with tools file",
|
|
||||||
toolName: "complex-tool",
|
|
||||||
toolsFileName: "tools.yaml",
|
|
||||||
wantContains: []string{
|
|
||||||
`const toolName = "complex-tool";`,
|
|
||||||
`const toolsFileName = "tools.yaml";`,
|
|
||||||
`configArgs.push("--tools-file", path.join(__dirname, "..", "assets", toolsFileName));`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := generateScriptContent(tt.toolName, tt.toolsFileName)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("generateScriptContent() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, s := range tt.wantContains {
|
|
||||||
if !strings.Contains(got, s) {
|
|
||||||
t.Errorf("generateScriptContent() missing substring %q\nGot:\n%s", s, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateToolConfigYAML(t *testing.T) {
|
|
||||||
cfg := server.ServerConfig{
|
|
||||||
ToolConfigs: server.ToolConfigs{
|
|
||||||
"tool1": MockToolConfig{
|
|
||||||
Type: "custom-tool",
|
|
||||||
Source: "src1",
|
|
||||||
Other: "foo",
|
|
||||||
},
|
|
||||||
"toolNoSource": MockToolConfig{
|
|
||||||
Type: "http",
|
|
||||||
},
|
|
||||||
"toolWithParams": MockToolConfig{
|
|
||||||
Type: "custom-tool",
|
|
||||||
Parameters: []parameters.Parameter{
|
|
||||||
parameters.NewStringParameter("param1", "desc1"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
SourceConfigs: server.SourceConfigs{
|
|
||||||
"src1": MockSourceConfig{
|
|
||||||
TypeVal: "postgres",
|
|
||||||
ConnVal: "conn1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
toolName string
|
|
||||||
wantContains []string
|
|
||||||
wantErr bool
|
|
||||||
wantNil bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "tool with source",
|
|
||||||
toolName: "tool1",
|
|
||||||
wantContains: []string{
|
|
||||||
"kind: tools",
|
|
||||||
"name: tool1",
|
|
||||||
"type: custom-tool",
|
|
||||||
"source: src1",
|
|
||||||
"other: foo",
|
|
||||||
"---",
|
|
||||||
"kind: sources",
|
|
||||||
"name: src1",
|
|
||||||
"type: postgres",
|
|
||||||
"connection_string: conn1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "tool without source",
|
|
||||||
toolName: "toolNoSource",
|
|
||||||
wantContains: []string{
|
|
||||||
"kind: tools",
|
|
||||||
"name: toolNoSource",
|
|
||||||
"type: http",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "tool with parameters",
|
|
||||||
toolName: "toolWithParams",
|
|
||||||
wantContains: []string{
|
|
||||||
"kind: tools",
|
|
||||||
"name: toolWithParams",
|
|
||||||
"type: custom-tool",
|
|
||||||
"parameters:",
|
|
||||||
"- name: param1",
|
|
||||||
"type: string",
|
|
||||||
"description: desc1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-existent tool",
|
|
||||||
toolName: "missing-tool",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
gotBytes, err := generateToolConfigYAML(cfg, tt.toolName)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("generateToolConfigYAML() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tt.wantErr {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.wantNil {
|
|
||||||
if gotBytes != nil {
|
|
||||||
t.Errorf("generateToolConfigYAML() expected nil, got %s", string(gotBytes))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
got := string(gotBytes)
|
|
||||||
for _, want := range tt.wantContains {
|
|
||||||
if !strings.Contains(got, want) {
|
|
||||||
t.Errorf("generateToolConfigYAML() result missing expected string: %q\nGot:\n%s", want, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
@@ -235,8 +234,10 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If auth error, return 401
|
// If auth error, return 401
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
errMsg := fmt.Sprintf("error parsing authenticated parameters from ID token: %w", err)
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
|
var clientServerErr *util.ClientServerError
|
||||||
|
if errors.As(err, &clientServerErr) && clientServerErr.Code == http.StatusUnauthorized {
|
||||||
|
s.logger.DebugContext(ctx, errMsg)
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -259,35 +260,50 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Determine what error to return to the users.
|
// Determine what error to return to the users.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
var statusCode int
|
|
||||||
|
|
||||||
// Upstream API auth error propagation
|
if errors.As(err, &tbErr) {
|
||||||
switch {
|
switch tbErr.Category() {
|
||||||
case strings.Contains(errStr, "Error 401"):
|
case util.CategoryAgent:
|
||||||
statusCode = http.StatusUnauthorized
|
// Agent Errors -> 200 OK
|
||||||
case strings.Contains(errStr, "Error 403"):
|
s.logger.DebugContext(ctx, fmt.Sprintf("Tool invocation agent error: %v", err))
|
||||||
statusCode = http.StatusForbidden
|
_ = render.Render(w, r, newErrResponse(err, http.StatusOK))
|
||||||
|
return
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// Server Errors -> Check the specific code inside
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
statusCode := http.StatusInternalServerError // Default to 500
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code != 0 {
|
||||||
|
statusCode = clientServerErr.Code
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process auth error
|
||||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
// Propagate the original 401/403 error.
|
// Token error, pass through 401/403
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
|
s.logger.DebugContext(ctx, fmt.Sprintf("Client credentials lack authorization: %v", err))
|
||||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// ADC lacking permission or credentials configuration error.
|
// ADC/Config error, return 500
|
||||||
internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
|
statusCode = http.StatusInternalServerError
|
||||||
s.logger.ErrorContext(ctx, internalErr.Error())
|
}
|
||||||
_ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
|
|
||||||
|
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation server error: %v", err))
|
||||||
|
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = fmt.Errorf("error while invoking tool: %w", err)
|
} else {
|
||||||
s.logger.DebugContext(ctx, err.Error())
|
// Unknown error -> 500
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation unknown error: %v", err))
|
||||||
|
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
resMarshal, err := json.Marshal(res)
|
resMarshal, err := json.Marshal(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/log"
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||||
@@ -40,6 +41,140 @@ var (
|
|||||||
_ prompts.Prompt = MockPrompt{}
|
_ prompts.Prompt = MockPrompt{}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MockTool is used to mock tools in tests
|
||||||
|
type MockTool struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Params []parameters.Parameter
|
||||||
|
manifest tools.Manifest
|
||||||
|
unauthorized bool
|
||||||
|
requiresClientAuthrorization bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) {
|
||||||
|
mock := []any{t.Name}
|
||||||
|
return mock, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) ToConfig() tools.ToolConfig {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// claims is a map of user info decoded from an auth token
|
||||||
|
func (t MockTool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
return parameters.ParseParams(t.Params, data, claimsMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
return parameters.EmbedParams(ctx, t.Params, paramValues, embeddingModelsMap, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) Manifest() tools.Manifest {
|
||||||
|
pMs := make([]parameters.ParameterManifest, 0, len(t.Params))
|
||||||
|
for _, p := range t.Params {
|
||||||
|
pMs = append(pMs, p.Manifest())
|
||||||
|
}
|
||||||
|
return tools.Manifest{Description: t.Description, Parameters: pMs}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) Authorized(verifiedAuthServices []string) bool {
|
||||||
|
// defaulted to true
|
||||||
|
return !t.unauthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
|
||||||
|
// defaulted to false
|
||||||
|
return t.requiresClientAuthrorization, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) GetParameters() parameters.Parameters {
|
||||||
|
return t.Params
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) McpManifest() tools.McpManifest {
|
||||||
|
properties := make(map[string]parameters.ParameterMcpManifest)
|
||||||
|
required := make([]string, 0)
|
||||||
|
authParams := make(map[string][]string)
|
||||||
|
|
||||||
|
for _, p := range t.Params {
|
||||||
|
name := p.GetName()
|
||||||
|
paramManifest, authParamList := p.McpManifest()
|
||||||
|
properties[name] = paramManifest
|
||||||
|
required = append(required, name)
|
||||||
|
|
||||||
|
if len(authParamList) > 0 {
|
||||||
|
authParams[name] = authParamList
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
toolsSchema := parameters.McpToolsSchema{
|
||||||
|
Type: "object",
|
||||||
|
Properties: properties,
|
||||||
|
Required: required,
|
||||||
|
}
|
||||||
|
|
||||||
|
mcpManifest := tools.McpManifest{
|
||||||
|
Name: t.Name,
|
||||||
|
Description: t.Description,
|
||||||
|
InputSchema: toolsSchema,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(authParams) > 0 {
|
||||||
|
mcpManifest.Metadata = map[string]any{
|
||||||
|
"toolbox/authParams": authParams,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mcpManifest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
|
||||||
|
return "Authorization", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockPrompt is used to mock prompts in tests
|
||||||
|
type MockPrompt struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Args prompts.Arguments
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MockPrompt) SubstituteParams(vals parameters.ParamValues) (any, error) {
|
||||||
|
return []prompts.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: fmt.Sprintf("substituted %s", p.Name),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MockPrompt) ParseArgs(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
var params parameters.Parameters
|
||||||
|
for _, arg := range p.Args {
|
||||||
|
params = append(params, arg.Parameter)
|
||||||
|
}
|
||||||
|
return parameters.ParseParams(params, data, claimsMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MockPrompt) Manifest() prompts.Manifest {
|
||||||
|
var argManifests []parameters.ParameterManifest
|
||||||
|
for _, arg := range p.Args {
|
||||||
|
argManifests = append(argManifests, arg.Manifest())
|
||||||
|
}
|
||||||
|
return prompts.Manifest{
|
||||||
|
Description: p.Description,
|
||||||
|
Arguments: argManifests,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MockPrompt) McpManifest() prompts.McpManifest {
|
||||||
|
return prompts.GetMcpManifest(p.Name, p.Description, p.Args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MockPrompt) ToConfig() prompts.PromptConfig {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var tool1 = MockTool{
|
var tool1 = MockTool{
|
||||||
Name: "no_params",
|
Name: "no_params",
|
||||||
Params: []parameters.Parameter{},
|
Params: []parameters.Parameter{},
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -444,18 +443,20 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
code := rpcResponse.Error.Code
|
code := rpcResponse.Error.Code
|
||||||
switch code {
|
switch code {
|
||||||
case jsonrpc.INTERNAL_ERROR:
|
case jsonrpc.INTERNAL_ERROR:
|
||||||
|
// Map Internal RPC Error (-32603) to HTTP 500
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
case jsonrpc.INVALID_REQUEST:
|
case jsonrpc.INVALID_REQUEST:
|
||||||
errStr := err.Error()
|
var clientServerErr *util.ClientServerError
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
if errors.As(err, &clientServerErr) {
|
||||||
|
switch clientServerErr.Code {
|
||||||
|
case http.StatusUnauthorized:
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
} else if strings.Contains(errStr, "Error 401") {
|
case http.StatusForbidden:
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
} else if strings.Contains(errStr, "Error 403") {
|
|
||||||
w.WriteHeader(http.StatusForbidden)
|
w.WriteHeader(http.StatusForbidden)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// send HTTP response
|
// send HTTP response
|
||||||
render.JSON(w, r, res)
|
render.JSON(w, r, res)
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -124,7 +123,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.NewClientServerError(
|
||||||
|
"missing access token in the 'Authorization' header",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,7 +175,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
err = util.NewClientServerError(
|
||||||
|
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -194,21 +201,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
// Missing authService tokens.
|
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Upstream auth error
|
|
||||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
|
||||||
if clientAuth {
|
|
||||||
// Error with client credentials should pass down to the client
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Auth error with ADC should raise internal 500 error
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if errors.As(err, &tbErr) {
|
||||||
|
switch tbErr.Category() {
|
||||||
|
case util.CategoryAgent:
|
||||||
|
// MCP - Tool execution error
|
||||||
|
// Return SUCCESS but with IsError: true
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -218,6 +217,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// MCP Spec - Protocol error
|
||||||
|
// Return JSON-RPC ERROR
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||||
|
if clientAuth {
|
||||||
|
rpcCode = jsonrpc.INVALID_REQUEST
|
||||||
|
} else {
|
||||||
|
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown error -> 500
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -124,7 +123,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.NewClientServerError(
|
||||||
|
"missing access token in the 'Authorization' header",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,7 +175,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
err = util.NewClientServerError(
|
||||||
|
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -194,20 +201,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
// Missing authService tokens.
|
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
if errors.As(err, &tbErr) {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
switch tbErr.Category() {
|
||||||
}
|
case util.CategoryAgent:
|
||||||
// Upstream auth error
|
// MCP - Tool execution error
|
||||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
// Return SUCCESS but with IsError: true
|
||||||
if clientAuth {
|
|
||||||
// Error with client credentials should pass down to the client
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Auth error with ADC should raise internal 500 error
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -217,8 +217,29 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// MCP Spec - Protocol error
|
||||||
|
// Return JSON-RPC ERROR
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||||
|
if clientAuth {
|
||||||
|
rpcCode = jsonrpc.INVALID_REQUEST
|
||||||
|
} else {
|
||||||
|
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown error -> 500
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
}
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|
||||||
sliceRes, ok := results.([]any)
|
sliceRes, ok := results.([]any)
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -117,7 +116,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
errMsg := "missing access token in the 'Authorization' header"
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
||||||
|
errMsg,
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,7 +169,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
err = util.NewClientServerError(
|
||||||
|
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -187,20 +195,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
// Missing authService tokens.
|
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
if errors.As(err, &tbErr) {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
switch tbErr.Category() {
|
||||||
}
|
case util.CategoryAgent:
|
||||||
// Upstream auth error
|
// MCP - Tool execution error
|
||||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
// Return SUCCESS but with IsError: true
|
||||||
if clientAuth {
|
|
||||||
// Error with client credentials should pass down to the client
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Auth error with ADC should raise internal 500 error
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -210,6 +211,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// MCP Spec - Protocol error
|
||||||
|
// Return JSON-RPC ERROR
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||||
|
if clientAuth {
|
||||||
|
rpcCode = jsonrpc.INVALID_REQUEST
|
||||||
|
} else {
|
||||||
|
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown error -> 500
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -117,7 +116,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.NewClientServerError(
|
||||||
|
"missing access token in the 'Authorization' header",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,7 +168,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
err = util.NewClientServerError(
|
||||||
|
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -187,20 +194,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
// Missing authService tokens.
|
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
if errors.As(err, &tbErr) {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
switch tbErr.Category() {
|
||||||
}
|
case util.CategoryAgent:
|
||||||
// Upstream auth error
|
// MCP - Tool execution error
|
||||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
// Return SUCCESS but with IsError: true
|
||||||
if clientAuth {
|
|
||||||
// Error with client credentials should pass down to the client
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Auth error with ADC should raise internal 500 error
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -210,6 +210,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// MCP Spec - Protocol error
|
||||||
|
// Return JSON-RPC ERROR
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||||
|
if clientAuth {
|
||||||
|
rpcCode = jsonrpc.INVALID_REQUEST
|
||||||
|
} else {
|
||||||
|
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown error -> 500
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -1,159 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockTool is used to mock tools in tests
|
|
||||||
type MockTool struct {
|
|
||||||
Name string
|
|
||||||
Description string
|
|
||||||
Params []parameters.Parameter
|
|
||||||
manifest tools.Manifest
|
|
||||||
unauthorized bool
|
|
||||||
requiresClientAuthrorization bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) {
|
|
||||||
mock := []any{t.Name}
|
|
||||||
return mock, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) ToConfig() tools.ToolConfig {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// claims is a map of user info decoded from an auth token
|
|
||||||
func (t MockTool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
|
||||||
return parameters.ParseParams(t.Params, data, claimsMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
|
||||||
return parameters.EmbedParams(ctx, t.Params, paramValues, embeddingModelsMap, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) Manifest() tools.Manifest {
|
|
||||||
pMs := make([]parameters.ParameterManifest, 0, len(t.Params))
|
|
||||||
for _, p := range t.Params {
|
|
||||||
pMs = append(pMs, p.Manifest())
|
|
||||||
}
|
|
||||||
return tools.Manifest{Description: t.Description, Parameters: pMs}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) Authorized(verifiedAuthServices []string) bool {
|
|
||||||
// defaulted to true
|
|
||||||
return !t.unauthorized
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
|
|
||||||
// defaulted to false
|
|
||||||
return t.requiresClientAuthrorization, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) GetParameters() parameters.Parameters {
|
|
||||||
return t.Params
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) McpManifest() tools.McpManifest {
|
|
||||||
properties := make(map[string]parameters.ParameterMcpManifest)
|
|
||||||
required := make([]string, 0)
|
|
||||||
authParams := make(map[string][]string)
|
|
||||||
|
|
||||||
for _, p := range t.Params {
|
|
||||||
name := p.GetName()
|
|
||||||
paramManifest, authParamList := p.McpManifest()
|
|
||||||
properties[name] = paramManifest
|
|
||||||
required = append(required, name)
|
|
||||||
|
|
||||||
if len(authParamList) > 0 {
|
|
||||||
authParams[name] = authParamList
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
toolsSchema := parameters.McpToolsSchema{
|
|
||||||
Type: "object",
|
|
||||||
Properties: properties,
|
|
||||||
Required: required,
|
|
||||||
}
|
|
||||||
|
|
||||||
mcpManifest := tools.McpManifest{
|
|
||||||
Name: t.Name,
|
|
||||||
Description: t.Description,
|
|
||||||
InputSchema: toolsSchema,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(authParams) > 0 {
|
|
||||||
mcpManifest.Metadata = map[string]any{
|
|
||||||
"toolbox/authParams": authParams,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return mcpManifest
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
|
|
||||||
return "Authorization", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockPrompt is used to mock prompts in tests
|
|
||||||
type MockPrompt struct {
|
|
||||||
Name string
|
|
||||||
Description string
|
|
||||||
Args prompts.Arguments
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p MockPrompt) SubstituteParams(vals parameters.ParamValues) (any, error) {
|
|
||||||
return []prompts.Message{
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: fmt.Sprintf("substituted %s", p.Name),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p MockPrompt) ParseArgs(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
|
||||||
var params parameters.Parameters
|
|
||||||
for _, arg := range p.Args {
|
|
||||||
params = append(params, arg.Parameter)
|
|
||||||
}
|
|
||||||
return parameters.ParseParams(params, data, claimsMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p MockPrompt) Manifest() prompts.Manifest {
|
|
||||||
var argManifests []parameters.ParameterManifest
|
|
||||||
for _, arg := range p.Args {
|
|
||||||
argManifests = append(argManifests, arg.Manifest())
|
|
||||||
}
|
|
||||||
return prompts.Manifest{
|
|
||||||
Description: p.Description,
|
|
||||||
Arguments: argManifests,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p MockPrompt) McpManifest() prompts.McpManifest {
|
|
||||||
return prompts.GetMcpManifest(p.Name, p.Description, p.Args)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p MockPrompt) ToConfig() prompts.PromptConfig {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -184,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if source.UseClientAuthorization() {
|
if source.UseClientAuthorization() {
|
||||||
// Use client-side access token
|
// Use client-side access token
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
|
return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil)
|
||||||
}
|
}
|
||||||
tokenStr, err = accessToken.ParseBearerToken()
|
tokenStr, err = accessToken.ParseBearerToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ package tools
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -80,7 +81,7 @@ type AccessToken string
|
|||||||
func (token AccessToken) ParseBearerToken() (string, error) {
|
func (token AccessToken) ParseBearerToken() (string, error) {
|
||||||
headerParts := strings.Split(string(token), " ")
|
headerParts := strings.Split(string(token), " ")
|
||||||
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
|
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
|
||||||
return "", fmt.Errorf("authorization header must be in the format 'Bearer <token>': %w", util.ErrUnauthorized)
|
return "", util.NewClientServerError("authorization header must be in the format 'Bearer <token>'", http.StatusUnauthorized, nil)
|
||||||
}
|
}
|
||||||
return headerParts[1], nil
|
return headerParts[1], nil
|
||||||
}
|
}
|
||||||
|
|||||||
61
internal/util/errors.go
Normal file
61
internal/util/errors.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
package util
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type ErrorCategory string
|
||||||
|
|
||||||
|
const (
|
||||||
|
CategoryAgent ErrorCategory = "AGENT_ERROR"
|
||||||
|
CategoryServer ErrorCategory = "SERVER_ERROR"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToolboxError is the interface all custom errors must satisfy
|
||||||
|
type ToolboxError interface {
|
||||||
|
error
|
||||||
|
Category() ErrorCategory
|
||||||
|
}
|
||||||
|
|
||||||
|
// Agent Errors return 200 to the sender
|
||||||
|
type AgentError struct {
|
||||||
|
Msg string
|
||||||
|
Cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AgentError) Error() string { return e.Msg }
|
||||||
|
|
||||||
|
func (e *AgentError) Category() ErrorCategory { return CategoryAgent }
|
||||||
|
|
||||||
|
func (e *AgentError) Unwrap() error { return e.Cause }
|
||||||
|
|
||||||
|
func NewAgentError(msg string, cause error) *AgentError {
|
||||||
|
return &AgentError{Msg: msg, Cause: cause}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientServerError returns 4XX/5XX error code
|
||||||
|
type ClientServerError struct {
|
||||||
|
Msg string
|
||||||
|
Code int
|
||||||
|
Cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ClientServerError) Error() string { return fmt.Sprintf("%s: %v", e.Msg, e.Cause) }
|
||||||
|
|
||||||
|
func (e *ClientServerError) Category() ErrorCategory { return CategoryServer }
|
||||||
|
|
||||||
|
func (e *ClientServerError) Unwrap() error { return e.Cause }
|
||||||
|
|
||||||
|
func NewClientServerError(msg string, code int, cause error) *ClientServerError {
|
||||||
|
return &ClientServerError{Msg: msg, Code: code, Cause: cause}
|
||||||
|
}
|
||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -118,7 +119,7 @@ func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[st
|
|||||||
}
|
}
|
||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("missing or invalid authentication header: %w", util.ErrUnauthorized)
|
return nil, util.NewClientServerError("missing or invalid authentication header", http.StatusUnauthorized, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckParamRequired checks if a parameter is required based on the required and default field.
|
// CheckParamRequired checks if a parameter is required based on the required and default field.
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -188,5 +187,3 @@ func InstrumentationFromContext(ctx context.Context) (*telemetry.Instrumentation
|
|||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrUnauthorized = errors.New("unauthorized")
|
|
||||||
|
|||||||
Reference in New Issue
Block a user