feat: Add HTTP Source and Tool (#332)

Add Source and Tool for tool invocation through HTTP requests.
This commit is contained in:
Wenxin Du
2025-04-02 14:52:35 -04:00
committed by GitHub
parent 559cb66791
commit 64da5b4efe
14 changed files with 1479 additions and 5 deletions

View File

@@ -214,6 +214,21 @@ steps:
- |
go test -race -v -tags=integration,dgraph ./tests
- id: "http"
name: golang:1
waitFor: ["install-dependencies"]
entrypoint: /bin/bash
env:
- "GOPATH=/gopath"
secretEnv: ["CLIENT_ID"]
volumes:
- name: "go"
path: "/gopath"
args:
- -c
- |
go test -race -v -tags=integration,http ./tests
availableSecrets:
secretManager:
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest

View File

@@ -43,3 +43,4 @@ run:
- dgraph
- mssql
- mysql
- http

View File

@@ -0,0 +1,42 @@
---
title: "HTTP"
linkTitle: "HTTP"
type: docs
weight: 1
description: >
The HTTP source enables the Toolbox to retrieve data from a remote server using HTTP requests.
---
## About
The HTTP Source allows the Gen AI Toolbox to retrieve data from arbitrary HTTP
endpoints. This enables Generative AI applications to access data from web APIs
and other HTTP-accessible resources.
## Example
```yaml
sources:
my-http-source:
kind: http
baseUrl: https://api.example.com/data
timeout: 10s # default to 30s
headers:
Authorization: Bearer YOUR_API_TOKEN
Content-Type: application/json
queryParams:
param1: value1
param2: value2
```
## Reference
| **field** | **type** | **required** | **description** |
|-------------|:-----------------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "http". |
| baseUrl | string | true | The base URL for the HTTP requests (e.g., `https://api.example.com`). |
| timeout | string | false | The timeout for HTTP requests (e.g., "5s", "1m", refer to [ParseDuration][parse-duration-doc] for more examples). Defaults to 30s. |
| headers | map[string]string | false | Default headers to include in the HTTP requests. |
| queryParams | map[string]string | false | Default query parameters to include in the HTTP requests. |
[parse-duration-doc]: https://pkg.go.dev/time#ParseDuration

View File

@@ -0,0 +1,225 @@
---
title: "http"
type: docs
weight: 1
description: >
A "http" tool sends out an HTTP request to an HTTP endpoint.
---
## About
The `http` tool allows you to make HTTP requests to APIs to retrieve data.
An HTTP request is the method by which a client communicates with a server to retrieve or manipulate resources.
Toolbox allows you to configure the request URL, method, headers, query parameters, and the request body for an HTTP Tool.
### URL
An HTTP request URL identifies the target the client wants to access.
Toolbox composes the request URL from the HTTP Source's `baseUrl` and the HTTP Tool's `path`.
For example, the following config allows you to reach different paths of the same server using multiple Tools:
```yaml
sources:
my-http-source:
kind: http
baseUrl: https://api.example.com
tools:
my-post-tool:
kind: http
source: my-http-source
method: POST
path: /update
description: Tool to update information to the example API
my-get-tool:
kind: http
source: my-http-source
method: GET
path: /search
description: Tool to search information from the example API
```
### Headers
An HTTP request header is a key-value pair sent by a client to a server, providing additional information about the request, such as the client's preferences, the request body content type, and other metadata.
Headers specified by the HTTP Tool are combined with its HTTP Source headers for the resulting HTTP request, and override the Source headers in case of conflict.
The HTTP Tool allows you to specify headers in two different ways:
- Static headers can be specified using the `headers` field, and will be the same for every invocation:
```yaml
my-http-tool:
kind: http
source: my-http-source
method: GET
path: /search
description: Tool to search data from API
headers:
Authorization: API_KEY
Content-Type: application/json
```
- Dynamic headers can be specified as parameters in the `headerParams` field. The `name` of the `headerParams` will be used as the header key, and the value is determined by the LLM input upon Tool invocation:
```yaml
my-http-tool:
kind: http
source: my-http-source
method: GET
path: /search
description: some description
headerParams:
- name: Content-Type # Example LLM input: "application/json"
description: request content type
type: string
```
### Query parameters
Query parameters are key-value pairs appended to a URL after a question mark (?) to provide additional information to the server for processing the request, like filtering or sorting data.
- Static request query parameters should be specified in the `path` as part of the URL itself:
```yaml
my-http-tool:
kind: http
source: my-http-source
method: GET
path: /search?language=en&id=1
description: Tool to search for item with ID 1 in English
```
- Dynamic request query parameters should be specified as parameters in the `queryParams` section:
```yaml
my-http-tool:
kind: http
source: my-http-source
method: GET
path: /search
description: Tool to search for item with ID
queryParams:
- name: id
description: item ID
type: integer
```
### Request body
The request body payload is a string that supports parameter replacement with following [Go template][go-template-doc]'s annotations.
The parameter names in the `requestBody` should be preceded by "." and enclosed by double curly brackets "{{}}". The values will be populated into the request body payload upon Tool invocation.
Example:
```yaml
my-http-tool:
kind: http
source: my-http-source
method: GET
path: /search
description: Tool to search for person with name and age
requestBody: |
{
"age": {{.age}},
"name": "{{.name}}"
}
bodyParams:
- name: age
description: age number
type: integer
- name: name
description: name string
type: string
```
#### Formatting Parameters
Some complex parameters (such as arrays) may require additional formatting to match the expected output. For convenience, you can specify one of the following pre-defined functions before the parameter name to format it:
##### JSON
The `json` keyword converts a parameter into a JSON format.
{{< notice note >}}
Using JSON may add quotes to the variable name for certain types (such as strings).
{{< /notice >}}
Example:
```yaml
requestBody: |
{
"age": {{json .age}},
"name": {{json .name}},
"nickname": "{{json .nickname}}",
"nameArray": {{json .nameArray}}
}
```
will send the following output:
```yaml
{
"age": 18,
"name": "Katherine",
"nickname": ""Kat"", # Duplicate quotes
"nameArray": ["A", "B", "C"]
}
```
## Example
```yaml
my-http-tool:
kind: http
source: my-http-source
method: GET
path: /search
description: some description
authRequired:
- my-google-auth-service
- other-auth-service
queryParams:
- name: country
description: some description
type: string
requestBody: |
{
"age": {{.age}},
"city": "{{.city}}"
}
bodyParams:
- name: age
description: age number
type: integer
- name: city
description: city string
type: string
headers:
Authorization: API_KEY
Content-Type: application/json
headerParams:
- name: Language
description: language string
type: string
```
## Reference
| **field** | **type** | **required** | **description** |
|--------------|:------------------------------------------:|:------------:|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "http". |
| source | string | true | Name of the source the HTTP request should be sent to. |
| description | string | true | Description of the tool that is passed to the LLM. |
| path | string | true | The path of the HTTP request. You can include static query parameters in the path string. |
| method | string | true | The HTTP method to use (e.g., GET, POST, PUT, DELETE). |
| headers | map[string]string | false | A map of headers to include in the HTTP request (overrides source headers). |
| requestBody | string | false | The request body payload. Use [go template][go-template-doc] with the parameter name as the placeholder (e.g., `{{.id}}` will be replaced with the value of the parameter that has name `id` in the `bodyParams` section). |
| queryParams | [parameters](_index#specifying-parameters) | false | List of [parameters](_index#specifying-parameters) that will be inserted into the query string. |
| bodyParams | [parameters](_index#specifying-parameters) | false | List of [parameters](_index#specifying-parameters) that will be inserted into the request body payload. |
| headerParams | [parameters](_index#specifying-parameters) | false | List of [parameters](_index#specifying-parameters) that will be inserted as the request headers. |
[go-template-doc]: <https://pkg.go.dev/text/template#pkg-overview>

View File

@@ -27,6 +27,7 @@ import (
cloudsqlmysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
dgraphsrc "github.com/googleapis/genai-toolbox/internal/sources/dgraph"
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
mssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/mssql"
mysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/mysql"
neo4jrc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
@@ -34,6 +35,7 @@ import (
spannersrc "github.com/googleapis/genai-toolbox/internal/sources/spanner"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/dgraph"
httptool "github.com/googleapis/genai-toolbox/internal/tools/http"
"github.com/googleapis/genai-toolbox/internal/tools/mssqlsql"
"github.com/googleapis/genai-toolbox/internal/tools/mysqlsql"
neo4jtool "github.com/googleapis/genai-toolbox/internal/tools/neo4j"
@@ -212,6 +214,12 @@ func (c *SourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interf
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
case httpsrc.SourceKind:
actual := httpsrc.DefaultConfig(name)
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of data source", kind)
}
@@ -329,6 +337,12 @@ func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interfac
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
case httptool.ToolKind:
actual := httptool.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of tool", kind)
}

View File

@@ -0,0 +1,91 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package http
import (
"context"
"fmt"
"net/http"
"net/url"
"time"
"github.com/googleapis/genai-toolbox/internal/sources"
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "http"
// validate interface
var _ sources.SourceConfig = Config{}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
BaseURL string `yaml:"baseUrl"`
Timeout string `yaml:"timeout"`
DefaultHeaders map[string]string `yaml:"headers"`
QueryParams map[string]string `yaml:"queryParams"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
}
// DefaultConfig is a helper function that generates the default configuration for an HTTP Tool Config.
func DefaultConfig(name string) Config {
return Config{Name: name, Timeout: "30s"}
}
// Initialize initializes an HTTP Source instance.
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
duration, err := time.ParseDuration(r.Timeout)
if err != nil {
return nil, fmt.Errorf("unable to parse Timeout string as time.Duration: %s", err)
}
client := http.Client{
Timeout: duration,
}
// Validate BaseURL
_, err = url.ParseRequestURI(r.BaseURL)
if err != nil {
return nil, fmt.Errorf("failed to parse BaseUrl %v", err)
}
s := &Source{
Name: r.Name,
Kind: SourceKind,
BaseURL: r.BaseURL,
DefaultHeaders: r.DefaultHeaders,
QueryParams: r.QueryParams,
Client: &client,
}
return s, nil
}
var _ sources.Source = &Source{}
type Source struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
BaseURL string `yaml:"baseUrl"`
DefaultHeaders map[string]string `yaml:"headers"`
QueryParams map[string]string `yaml:"queryParams"`
Client *http.Client
}
func (s *Source) SourceKind() string {
return SourceKind
}

View File

@@ -0,0 +1,128 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package http_test
import (
"strings"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/http"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
func TestParseFromYamlHttp(t *testing.T) {
tcs := []struct {
desc string
in string
want server.SourceConfigs
}{
{
desc: "basic example",
in: `
sources:
my-http-instance:
kind: http
baseUrl: http://test_server/
timeout: 10s
headers:
Authorization: test_header
Custom-Header: custom
queryParams:
api-key: test_api_key
param: param-value
`,
want: map[string]sources.SourceConfig{
"my-http-instance": http.Config{
Name: "my-http-instance",
Kind: http.SourceKind,
BaseURL: "http://test_server/",
Timeout: "10s",
DefaultHeaders: map[string]string{"Authorization": "test_header", "Custom-Header": "custom"},
QueryParams: map[string]string{"api-key": "test_api_key", "param": "param-value"},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
}
func TestFailParseFromYaml(t *testing.T) {
tcs := []struct {
desc string
in string
err string
}{
{
desc: "extra field",
in: `
sources:
my-http-instance:
kind: http
baseUrl: http://test_server/
timeout: 10s
headers:
Authorization: test_header
queryParams:
api-key: test_api_key
project: test-project
`,
err: "unable to parse as \"http\"",
},
{
desc: "missing required field",
in: `
sources:
my-http-instance:
baseUrl: http://test_server/
`,
err: "missing 'kind' field for \"my-http-instance\"",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}
errStr := err.Error()
if !strings.Contains(errStr, tc.err) {
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
}
})
}
}

276
internal/tools/http/http.go Normal file
View File

@@ -0,0 +1,276 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package http
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"slices"
"strings"
"maps"
"text/template"
"github.com/googleapis/genai-toolbox/internal/sources"
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
"github.com/googleapis/genai-toolbox/internal/tools"
)
const ToolKind string = "http"
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
Path string `yaml:"path" validate:"required"`
Method tools.HTTPMethod `yaml:"method" validate:"required"`
Headers map[string]string `yaml:"headers"`
RequestBody string `yaml:"requestBody"`
QueryParams tools.Parameters `yaml:"queryParams"`
BodyParams tools.Parameters `yaml:"bodyParams"`
HeaderParams tools.Parameters `yaml:"headerParams"`
}
// validate interface
var _ tools.ToolConfig = Config{}
func (cfg Config) ToolConfigKind() string {
return ToolKind
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(*httpsrc.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `http`", ToolKind)
}
// Create URL based on BaseURL and Path
// Attach query parameters
u, err := url.Parse(s.BaseURL + cfg.Path)
if err != nil {
return nil, fmt.Errorf("error parsing URL: %s", err)
}
// Get existing query parameters from the URL
queryParameters := u.Query()
for key, value := range s.QueryParams {
queryParameters.Add(key, value)
}
u.RawQuery = queryParameters.Encode()
// Combine Source and Tool headers.
// In case of conflict, Tool header overrides Source header
combinedHeaders := make(map[string]string)
maps.Copy(combinedHeaders, s.DefaultHeaders)
maps.Copy(combinedHeaders, cfg.Headers)
// Create a slice for all parameters
allParameters := slices.Concat(cfg.BodyParams, cfg.HeaderParams, cfg.QueryParams)
// Create parameter manifest
paramManifest := slices.Concat(
cfg.QueryParams.Manifest(),
cfg.BodyParams.Manifest(),
cfg.HeaderParams.Manifest(),
)
if paramManifest == nil {
paramManifest = make([]tools.ParameterManifest, 0)
}
// Verify there are no duplicate parameter names
seenNames := make(map[string]bool)
for _, param := range paramManifest {
if _, exists := seenNames[param.Name]; exists {
return nil, fmt.Errorf("parameter name must be unique across queryParams, bodyParams, and headerParams. Duplicate parameter: %s", param.Name)
}
seenNames[param.Name] = true
}
// finish tool setup
return Tool{
Name: cfg.Name,
Kind: ToolKind,
URL: u,
Method: cfg.Method,
AuthRequired: cfg.AuthRequired,
RequestBody: cfg.RequestBody,
QueryParams: cfg.QueryParams,
BodyParams: cfg.BodyParams,
HeaderParams: cfg.HeaderParams,
Headers: combinedHeaders,
Client: s.Client,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest},
}, nil
}
// validate interface
var _ tools.Tool = Tool{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Description string `yaml:"description"`
AuthRequired []string `yaml:"authRequired"`
URL *url.URL `yaml:"url"`
Method tools.HTTPMethod `yaml:"method"`
Headers map[string]string `yaml:"headers"`
RequestBody string `yaml:"requestBody"`
QueryParams tools.Parameters `yaml:"queryParams"`
BodyParams tools.Parameters `yaml:"bodyParams"`
HeaderParams tools.Parameters `yaml:"headerParams"`
AllParams tools.Parameters `yaml:"allParams"`
Client *http.Client
manifest tools.Manifest
}
// helper function to convert a parameter to JSON formatted string.
func convertParamToJSON(param any) (string, error) {
jsonData, err := json.Marshal(param)
if err != nil {
return "", fmt.Errorf("failed to marshal param to JSON: %w", err)
}
return string(jsonData), nil
}
// Helper function to generate the HTTP request body upon Tool invocation.
func getRequestBody(bodyParams tools.Parameters, requestBodyPayload string, paramsMap map[string]any) (string, error) {
// Create a map for request body parameters
bodyParamsMap := make(map[string]any)
for _, p := range bodyParams {
k := p.GetName()
v, ok := paramsMap[k]
if !ok {
return "", fmt.Errorf("missing request body parameter %s", k)
}
bodyParamsMap[k] = v
}
// Create a FuncMap to format array parameters
funcMap := template.FuncMap{
"json": convertParamToJSON,
}
templ, err := template.New("body").Funcs(funcMap).Parse(requestBodyPayload)
if err != nil {
return "", fmt.Errorf("error parsing request body: %s", err)
}
var result bytes.Buffer
err = templ.Execute(&result, bodyParamsMap)
if err != nil {
return "", fmt.Errorf("error replacing body payload: %s", err)
}
return result.String(), nil
}
// Helper function to generate the HTTP request URL upon Tool invocation.
func getURL(u *url.URL, queryParams tools.Parameters, paramsMap map[string]any) (string, error) {
// Set dynamic query parameters
query := u.Query()
for _, p := range queryParams {
query.Add(p.GetName(), fmt.Sprintf("%v", paramsMap[p.GetName()]))
}
u.RawQuery = query.Encode()
return u.String(), nil
}
// Helper function to generate the HTTP headers upon Tool invocation.
func getHeaders(headerParams tools.Parameters, defaultHeaders map[string]string, paramsMap map[string]any) (map[string]string, error) {
// Populate header params
allHeaders := make(map[string]string)
maps.Copy(allHeaders, defaultHeaders)
for _, p := range headerParams {
headerValue, ok := paramsMap[p.GetName()]
if ok {
if strValue, ok := headerValue.(string); ok {
allHeaders[p.GetName()] = strValue
} else {
return nil, fmt.Errorf("header param %s got value of type %t, not string", p.GetName(), headerValue)
}
}
}
return allHeaders, nil
}
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
paramsMap := params.AsMap()
// Calculate request body
requestBody, err := getRequestBody(t.BodyParams, t.RequestBody, paramsMap)
if err != nil {
return nil, fmt.Errorf("error populating request body: %s", err)
}
// Calculate URL
urlString, err := getURL(t.URL, t.QueryParams, paramsMap)
if err != nil {
return nil, fmt.Errorf("error populating query parameters: %s", err)
}
req, _ := http.NewRequest(string(t.Method), urlString, strings.NewReader(requestBody))
// Calculate request headers
allHeaders, err := getHeaders(t.HeaderParams, t.Headers, paramsMap)
if err != nil {
return nil, fmt.Errorf("error populating request headers: %s", err)
}
// Set request headers
for k, v := range allHeaders {
req.Header.Set(k, v)
}
// Make request and fetch response
resp, err := t.Client.Do(req)
if err != nil {
return nil, fmt.Errorf("error making HTTP request: %s", err)
}
defer resp.Body.Close()
var body []byte
body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
}
return []any{string(body)}, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.AllParams, data, claims)
}
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -0,0 +1,199 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package http_test
import (
"strings"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
http "github.com/googleapis/genai-toolbox/internal/tools/http"
)
func TestParseFromYamlHTTP(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
example_tool:
kind: http
source: my-instance
method: GET
path: "search?name=alice&pet=cat"
description: some description
authRequired:
- my-google-auth-service
- other-auth-service
queryParams:
- name: country
type: string
description: some description
authServices:
- name: my-google-auth-service
field: user_id
- name: other-auth-service
field: user_id
requestBody: |
{
"age": {{.age}}
"city": "{{.city}}"
"food": {{.food}}
}
bodyParams:
- name: age
type: integer
description: age num
- name: city
type: string
description: city string
headers:
Authorization: API_KEY
Content-Type: application/json
headerParams:
- name: Language
type: string
description: language string
`,
want: server.ToolConfigs{
"example_tool": http.Config{
Name: "example_tool",
Kind: http.ToolKind,
Source: "my-instance",
Method: "GET",
Path: "search?name=alice&pet=cat",
Description: "some description",
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
QueryParams: []tools.Parameter{
tools.NewStringParameterWithAuth("country", "some description",
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
{Name: "other-auth-service", Field: "user_id"}}),
},
RequestBody: `{
"age": {{.age}}
"city": "{{.city}}"
"food": {{.food}}
}
`,
BodyParams: []tools.Parameter{tools.NewIntParameter("age", "age num"), tools.NewStringParameter("city", "city string")},
Headers: map[string]string{"Authorization": "API_KEY", "Content-Type": "application/json"},
HeaderParams: []tools.Parameter{tools.NewStringParameter("Language", "language string")},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}
func TestFailParseFromYamlHTTP(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
err string
}{
{
desc: "Invalid method",
in: `
tools:
example_tool:
kind: http
source: my-instance
method: GOT
path: "search?name=alice&pet=cat"
description: some description
authRequired:
- my-google-auth-service
- other-auth-service
queryParams:
- name: country
type: string
description: some description
authServices:
- name: my-google-auth-service
field: user_id
- name: other-auth-service
field: user_id
requestBody: |
{
"age": {{.age}},
"city": "{{.city}}",
}
bodyParams:
- name: age
type: integer
description: age num
- name: city
type: string
description: city string
headers:
Authorization: API_KEY
Content-Type: application/json
headerParams:
- name: Language
type: string
description: language string
`,
err: `GOT is not a valid http method`,
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}
errStr := err.Error()
if !strings.Contains(errStr, tc.err) {
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
}
})
}
}

View File

@@ -0,0 +1,49 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tools
import (
"context"
"fmt"
"net/http"
"strings"
)
// HTTPMethod is a string of a valid HTTP method (e.g "GET")
type HTTPMethod string
// isValidHTTPMethod checks if the input string matches one of the method constants defined in the net/http package
func isValidHTTPMethod(method string) bool {
switch method {
case http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete,
http.MethodPatch, http.MethodHead, http.MethodOptions, http.MethodTrace,
http.MethodConnect:
return true
}
return false
}
func (i *HTTPMethod) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
var httpMethod string
if err := unmarshal(&httpMethod); err != nil {
return fmt.Errorf(`error unmarshalling HTTP method: %s`, err)
}
httpMethod = strings.ToUpper(httpMethod)
if !isValidHTTPMethod(httpMethod) {
return fmt.Errorf(`%s is not a valid http method`, httpMethod)
}
*i = HTTPMethod(httpMethod)
return nil
}

View File

@@ -122,7 +122,7 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st
var v any
paramAuthServices := p.GetAuthServices()
name := p.GetName()
if paramAuthServices == nil {
if len(paramAuthServices) == 0 {
// parse non auth-required parameter
var ok bool
v, ok = data[name]
@@ -567,8 +567,8 @@ func (p *ArrayParameter) UnmarshalYAML(ctx context.Context, unmarshal func(inter
if err != nil {
return fmt.Errorf("unable to parse 'items' field: %w", err)
}
if i.GetAuthServices() != nil {
return fmt.Errorf("nested items should not have auth services.")
if i.GetAuthServices() != nil && len(i.GetAuthServices()) != 0 {
return fmt.Errorf("nested items should not have auth services")
}
p.Items = i

View File

@@ -22,6 +22,8 @@ package tests
import (
"database/sql"
"fmt"
"github.com/googleapis/genai-toolbox/internal/tools"
)
// GetToolsConfig returns a mock tools config file
@@ -97,6 +99,97 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, param_tool_statement,
return toolsFile
}
// GetHTTPToolsConfig returns a mock HTTP tool's config file
func GetHTTPToolsConfig(sourceConfig map[string]any, toolKind string) map[string]any {
// Write config into a file and pass it to command
otherSourceConfig := make(map[string]any)
for k, v := range sourceConfig {
otherSourceConfig[k] = v
}
otherSourceConfig["headers"] = map[string]string{"X-Custom-Header": "unexpected", "Content-Type": "application/json"}
otherSourceConfig["queryParams"] = map[string]any{"id": 1, "name": "Sid"}
toolsFile := map[string]any{
"sources": map[string]any{
"my-instance": sourceConfig,
"other-instance": otherSourceConfig,
},
"authServices": map[string]any{
"my-google-auth": map[string]any{
"kind": "google",
"clientId": ClientId,
},
},
"tools": map[string]any{
"my-simple-tool": map[string]any{
"kind": toolKind,
"path": "/tool0",
"method": "POST",
"source": "my-instance",
"requestBody": "{}",
"description": "Simple tool to test end to end functionality.",
},
"my-param-tool": map[string]any{
"kind": toolKind,
"source": "my-instance",
"method": "GET",
"path": "/tool1",
"description": "some description",
"queryParams": []tools.Parameter{
tools.NewIntParameter("id", "user ID")},
"requestBody": `{
"age": 36,
"name": "{{.name}}"
}
`,
"bodyParams": []tools.Parameter{tools.NewStringParameter("name", "user name")},
"headers": map[string]string{"Content-Type": "application/json"},
},
"my-auth-tool": map[string]any{
"kind": toolKind,
"source": "my-instance",
"method": "GET",
"path": "/tool2",
"description": "some description",
"requestBody": "{}",
"queryParams": []tools.Parameter{
tools.NewStringParameterWithAuth("email", "some description",
[]tools.ParamAuthService{{Name: "my-google-auth", Field: "email"}}),
},
},
"my-auth-required-tool": map[string]any{
"kind": toolKind,
"source": "my-instance",
"method": "POST",
"path": "/tool0",
"description": "some description",
"requestBody": "{}",
"authRequired": []string{"my-google-auth"},
},
"my-advanced-tool": map[string]any{
"kind": toolKind,
"source": "other-instance",
"method": "get",
"path": "/tool3?id=2",
"description": "some description",
"headers": map[string]string{
"X-Custom-Header": "example",
},
"queryParams": []tools.Parameter{
tools.NewIntParameter("id", "user ID"), tools.NewStringParameter("country", "country")},
"requestBody": `{
"place": "zoo",
"animals": {{json .animalArray }}
}
`,
"bodyParams": []tools.Parameter{tools.NewArrayParameter("animalArray", "animals in the zoo", tools.NewStringParameter("animals", "desc"))},
"headerParams": []tools.Parameter{tools.NewStringParameter("X-Other-Header", "custom header")},
},
},
}
return toolsFile
}
// GetPostgresSQLParamToolInfo returns statements and param for my-param-tool postgres-sql kind
func GetPostgresSQLParamToolInfo(tableName string) (string, string, string, []any) {
create_statement := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT);", tableName)

335
tests/http_test.go Normal file
View File

@@ -0,0 +1,335 @@
//go:build integration && http
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tests
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"reflect"
"regexp"
"strings"
"testing"
"time"
)
var (
HTTP_SOURCE_KIND = "http"
HTTP_TOOL_KIND = "http"
)
func getHTTPSourceConfig(t *testing.T) map[string]any {
idToken, err := GetGoogleIdToken(ClientId)
if err != nil {
t.Fatalf("error getting ID token: %s", err)
}
idToken = "Bearer " + idToken
return map[string]any{
"kind": HTTP_SOURCE_KIND,
"headers": map[string]string{"Authorization": idToken},
}
}
// handler function for the test server
func multiTool(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
path = strings.TrimPrefix(path, "/") // Remove leading slash
switch path {
case "tool0":
handleTool0(w, r)
case "tool1":
handleTool1(w, r)
case "tool2":
handleTool2(w, r)
case "tool3":
handleTool3(w, r)
default:
http.NotFound(w, r) // Return 404 for unknown paths
}
}
// handler function for the test server
func handleTool0(w http.ResponseWriter, r *http.Request) {
// expect POST method
if r.Method != http.MethodPost {
errorMessage := fmt.Sprintf("expected POST method but got: %s", string(r.Method))
http.Error(w, errorMessage, http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
response := []string{
"Hello",
"World",
}
err := json.NewEncoder(w).Encode(response)
if err != nil {
http.Error(w, "Failed to encode JSON", http.StatusInternalServerError)
return
}
}
// handler function for the test server
func handleTool1(w http.ResponseWriter, r *http.Request) {
// expect GET method
if r.Method != http.MethodGet {
errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method))
http.Error(w, errorMessage, http.StatusBadRequest)
return
}
// Parse request body
var requestBody map[string]interface{}
bodyBytes, readErr := io.ReadAll(r.Body)
if readErr != nil {
http.Error(w, "Bad Request: Failed to read request body", http.StatusBadRequest)
return
}
defer r.Body.Close()
err := json.Unmarshal(bodyBytes, &requestBody)
if err != nil {
errorMessage := fmt.Sprintf("Bad Request: Error unmarshalling request body: %s, Raw body: %s", err, string(bodyBytes))
http.Error(w, errorMessage, http.StatusBadRequest)
return
}
// Extract name
name, ok := requestBody["name"].(string)
if !ok || name == "" {
http.Error(w, "Bad Request: Missing or invalid name", http.StatusBadRequest)
return
}
if name == "Alice" {
response := `{"id":1,"name":"Alice"},{"id":3,"name":"Sid"}`
_, err := w.Write([]byte(response))
if err != nil {
http.Error(w, "Failed to write response", http.StatusInternalServerError)
}
return
}
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// handler function for the test server
func handleTool2(w http.ResponseWriter, r *http.Request) {
// expect GET method
if r.Method != http.MethodGet {
errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method))
http.Error(w, errorMessage, http.StatusBadRequest)
return
}
email := r.URL.Query().Get("email")
if email != "" {
response := `{"name":"Alice"}`
_, err := w.Write([]byte(response))
if err != nil {
http.Error(w, "Failed to write response", http.StatusInternalServerError)
}
return
}
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// handler function for the test server
func handleTool3(w http.ResponseWriter, r *http.Request) {
// expect GET method
if r.Method != http.MethodGet {
errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method))
http.Error(w, errorMessage, http.StatusBadRequest)
return
}
// Check request headers
expectedHeaders := map[string]string{
"Content-Type": "application/json",
"X-Custom-Header": "example",
"X-Other-Header": "test",
}
for header, expectedValue := range expectedHeaders {
if r.Header.Get(header) != expectedValue {
errorMessage := fmt.Sprintf("Bad Request: Missing or incorrect header: %s", header)
http.Error(w, errorMessage, http.StatusBadRequest)
return
}
}
// Check query parameters
expectedQueryParams := map[string][]string{
"id": []string{"2", "1", "3"},
"country": []string{"US"},
}
query := r.URL.Query()
for param, expectedValueSlice := range expectedQueryParams {
values, ok := query[param]
if ok {
if !reflect.DeepEqual(expectedValueSlice, values) {
errorMessage := fmt.Sprintf("Bad Request: Incorrect query parameter: %s, actual: %s", param, query[param])
http.Error(w, errorMessage, http.StatusBadRequest)
return
}
} else {
errorMessage := fmt.Sprintf("Bad Request: Missing query parameter: %s, actual: %s", param, query[param])
http.Error(w, errorMessage, http.StatusBadRequest)
return
}
}
// Parse request body
var requestBody map[string]interface{}
bodyBytes, readErr := io.ReadAll(r.Body)
if readErr != nil {
http.Error(w, "Bad Request: Failed to read request body", http.StatusBadRequest)
return
}
defer r.Body.Close()
err := json.Unmarshal(bodyBytes, &requestBody)
if err != nil {
errorMessage := fmt.Sprintf("Bad Request: Error unmarshalling request body: %s, Raw body: %s", err, string(bodyBytes))
http.Error(w, errorMessage, http.StatusBadRequest)
return
}
// Check request body
expectedBody := map[string]interface{}{
"place": "zoo",
"animals": []any{"rabbit", "ostrich", "whale"},
}
if !reflect.DeepEqual(requestBody, expectedBody) {
errorMessage := fmt.Sprintf("Bad Request: Incorrect request body. Expected: %v, Got: %v", expectedBody, requestBody)
http.Error(w, errorMessage, http.StatusBadRequest)
return
}
// Return a JSON array as the response
response := []any{
"Hello", "World",
}
err = json.NewEncoder(w).Encode(response)
if err != nil {
http.Error(w, "Failed to encode JSON", http.StatusInternalServerError)
return
}
}
func TestToolEndpoints(t *testing.T) {
// start a test server
server := httptest.NewServer(http.HandlerFunc(multiTool))
defer server.Close()
sourceConfig := getHTTPSourceConfig(t)
sourceConfig["baseUrl"] = server.URL
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
toolsFile := GetHTTPToolsConfig(sourceConfig, HTTP_TOOL_KIND)
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
select_1_want := `["[\"Hello\",\"World\"]\n"]`
RunToolGetTest(t)
RunToolInvokeTest(t, select_1_want)
RunAdvancedHTTPInvokeTest(t)
}
// RunToolInvoke runs the tool invoke endpoint
func RunAdvancedHTTPInvokeTest(t *testing.T) {
// Test HTTP tool invoke endpoint
invokeTcs := []struct {
name string
api string
requestHeader map[string]string
requestBody io.Reader
want string
isErr bool
}{
{
name: "invoke my-advanced-tool",
api: "http://127.0.0.1:5000/api/tool/my-advanced-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"animalArray": ["rabbit", "ostrich", "whale"], "id": 3, "country": "US", "X-Other-Header": "test"}`)),
want: `["[\"Hello\",\"World\"]\n"]`,
isErr: false,
},
{
name: "invoke my-advanced-tool with wrong params",
api: "http://127.0.0.1:5000/api/tool/my-advanced-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"animalArray": ["rabbit", "ostrich", "whale"], "id": 4, "country": "US", "X-Other-Header": "test"}`)),
isErr: true,
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
// Send Tool invocation request
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
for k, v := range tc.requestHeader {
req.Header.Add(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if tc.isErr == true {
return
}
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
// Check response body
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if got != tc.want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})
}
}

View File

@@ -25,6 +25,7 @@ import (
"io"
"net/http"
"reflect"
"strings"
"testing"
"github.com/jackc/pgx/v5/pgxpool"
@@ -287,12 +288,17 @@ func RunToolInvokeTest(t *testing.T, select_1_want string) {
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if got != tc.want {
// Remove `\` and `"` for string comparison
got = strings.ReplaceAll(got, "\\", "")
want := strings.ReplaceAll(tc.want, "\\", "")
got = strings.ReplaceAll(got, "\"", "")
want = strings.ReplaceAll(want, "\"", "")
if got != want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})