mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 16:08:16 -05:00
feat: Add HTTP Source and Tool (#332)
Add Source and Tool for tool invocation through HTTP requests.
This commit is contained in:
@@ -214,6 +214,21 @@ steps:
|
||||
- |
|
||||
go test -race -v -tags=integration,dgraph ./tests
|
||||
|
||||
- id: "http"
|
||||
name: golang:1
|
||||
waitFor: ["install-dependencies"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
secretEnv: ["CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
go test -race -v -tags=integration,http ./tests
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
|
||||
|
||||
@@ -43,3 +43,4 @@ run:
|
||||
- dgraph
|
||||
- mssql
|
||||
- mysql
|
||||
- http
|
||||
|
||||
42
docs/en/resources/sources/http.md
Normal file
42
docs/en/resources/sources/http.md
Normal file
@@ -0,0 +1,42 @@
|
||||
---
|
||||
title: "HTTP"
|
||||
linkTitle: "HTTP"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
The HTTP source enables the Toolbox to retrieve data from a remote server using HTTP requests.
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The HTTP Source allows the Gen AI Toolbox to retrieve data from arbitrary HTTP
|
||||
endpoints. This enables Generative AI applications to access data from web APIs
|
||||
and other HTTP-accessible resources.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-http-source:
|
||||
kind: http
|
||||
baseUrl: https://api.example.com/data
|
||||
timeout: 10s # default to 30s
|
||||
headers:
|
||||
Authorization: Bearer YOUR_API_TOKEN
|
||||
Content-Type: application/json
|
||||
queryParams:
|
||||
param1: value1
|
||||
param2: value2
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:-----------------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "http". |
|
||||
| baseUrl | string | true | The base URL for the HTTP requests (e.g., `https://api.example.com`). |
|
||||
| timeout | string | false | The timeout for HTTP requests (e.g., "5s", "1m", refer to [ParseDuration][parse-duration-doc] for more examples). Defaults to 30s. |
|
||||
| headers | map[string]string | false | Default headers to include in the HTTP requests. |
|
||||
| queryParams | map[string]string | false | Default query parameters to include in the HTTP requests. |
|
||||
|
||||
[parse-duration-doc]: https://pkg.go.dev/time#ParseDuration
|
||||
225
docs/en/resources/tools/http.md
Normal file
225
docs/en/resources/tools/http.md
Normal file
@@ -0,0 +1,225 @@
|
||||
---
|
||||
title: "http"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "http" tool sends out an HTTP request to an HTTP endpoint.
|
||||
---
|
||||
|
||||
|
||||
## About
|
||||
|
||||
The `http` tool allows you to make HTTP requests to APIs to retrieve data.
|
||||
An HTTP request is the method by which a client communicates with a server to retrieve or manipulate resources.
|
||||
Toolbox allows you to configure the request URL, method, headers, query parameters, and the request body for an HTTP Tool.
|
||||
|
||||
### URL
|
||||
|
||||
An HTTP request URL identifies the target the client wants to access.
|
||||
Toolbox composes the request URL from the HTTP Source's `baseUrl` and the HTTP Tool's `path`.
|
||||
For example, the following config allows you to reach different paths of the same server using multiple Tools:
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-http-source:
|
||||
kind: http
|
||||
baseUrl: https://api.example.com
|
||||
|
||||
tools:
|
||||
my-post-tool:
|
||||
kind: http
|
||||
source: my-http-source
|
||||
method: POST
|
||||
path: /update
|
||||
description: Tool to update information to the example API
|
||||
|
||||
my-get-tool:
|
||||
kind: http
|
||||
source: my-http-source
|
||||
method: GET
|
||||
path: /search
|
||||
description: Tool to search information from the example API
|
||||
|
||||
```
|
||||
|
||||
### Headers
|
||||
|
||||
An HTTP request header is a key-value pair sent by a client to a server, providing additional information about the request, such as the client's preferences, the request body content type, and other metadata.
|
||||
Headers specified by the HTTP Tool are combined with its HTTP Source headers for the resulting HTTP request, and override the Source headers in case of conflict.
|
||||
The HTTP Tool allows you to specify headers in two different ways:
|
||||
|
||||
- Static headers can be specified using the `headers` field, and will be the same for every invocation:
|
||||
|
||||
```yaml
|
||||
my-http-tool:
|
||||
kind: http
|
||||
source: my-http-source
|
||||
method: GET
|
||||
path: /search
|
||||
description: Tool to search data from API
|
||||
headers:
|
||||
Authorization: API_KEY
|
||||
Content-Type: application/json
|
||||
```
|
||||
|
||||
- Dynamic headers can be specified as parameters in the `headerParams` field. The `name` of the `headerParams` will be used as the header key, and the value is determined by the LLM input upon Tool invocation:
|
||||
|
||||
```yaml
|
||||
my-http-tool:
|
||||
kind: http
|
||||
source: my-http-source
|
||||
method: GET
|
||||
path: /search
|
||||
description: some description
|
||||
headerParams:
|
||||
- name: Content-Type # Example LLM input: "application/json"
|
||||
description: request content type
|
||||
type: string
|
||||
```
|
||||
|
||||
### Query parameters
|
||||
|
||||
Query parameters are key-value pairs appended to a URL after a question mark (?) to provide additional information to the server for processing the request, like filtering or sorting data.
|
||||
|
||||
- Static request query parameters should be specified in the `path` as part of the URL itself:
|
||||
|
||||
```yaml
|
||||
my-http-tool:
|
||||
kind: http
|
||||
source: my-http-source
|
||||
method: GET
|
||||
path: /search?language=en&id=1
|
||||
description: Tool to search for item with ID 1 in English
|
||||
```
|
||||
|
||||
- Dynamic request query parameters should be specified as parameters in the `queryParams` section:
|
||||
|
||||
```yaml
|
||||
my-http-tool:
|
||||
kind: http
|
||||
source: my-http-source
|
||||
method: GET
|
||||
path: /search
|
||||
description: Tool to search for item with ID
|
||||
queryParams:
|
||||
- name: id
|
||||
description: item ID
|
||||
type: integer
|
||||
```
|
||||
|
||||
### Request body
|
||||
|
||||
The request body payload is a string that supports parameter replacement with following [Go template][go-template-doc]'s annotations.
|
||||
The parameter names in the `requestBody` should be preceded by "." and enclosed by double curly brackets "{{}}". The values will be populated into the request body payload upon Tool invocation.
|
||||
|
||||
Example:
|
||||
|
||||
```yaml
|
||||
my-http-tool:
|
||||
kind: http
|
||||
source: my-http-source
|
||||
method: GET
|
||||
path: /search
|
||||
description: Tool to search for person with name and age
|
||||
requestBody: |
|
||||
{
|
||||
"age": {{.age}},
|
||||
"name": "{{.name}}"
|
||||
}
|
||||
bodyParams:
|
||||
- name: age
|
||||
description: age number
|
||||
type: integer
|
||||
- name: name
|
||||
description: name string
|
||||
type: string
|
||||
```
|
||||
|
||||
#### Formatting Parameters
|
||||
|
||||
Some complex parameters (such as arrays) may require additional formatting to match the expected output. For convenience, you can specify one of the following pre-defined functions before the parameter name to format it:
|
||||
|
||||
##### JSON
|
||||
|
||||
The `json` keyword converts a parameter into a JSON format.
|
||||
|
||||
{{< notice note >}}
|
||||
Using JSON may add quotes to the variable name for certain types (such as strings).
|
||||
{{< /notice >}}
|
||||
|
||||
Example:
|
||||
|
||||
```yaml
|
||||
requestBody: |
|
||||
{
|
||||
"age": {{json .age}},
|
||||
"name": {{json .name}},
|
||||
"nickname": "{{json .nickname}}",
|
||||
"nameArray": {{json .nameArray}}
|
||||
}
|
||||
```
|
||||
|
||||
will send the following output:
|
||||
|
||||
```yaml
|
||||
{
|
||||
"age": 18,
|
||||
"name": "Katherine",
|
||||
"nickname": ""Kat"", # Duplicate quotes
|
||||
"nameArray": ["A", "B", "C"]
|
||||
}
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
my-http-tool:
|
||||
kind: http
|
||||
source: my-http-source
|
||||
method: GET
|
||||
path: /search
|
||||
description: some description
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
queryParams:
|
||||
- name: country
|
||||
description: some description
|
||||
type: string
|
||||
requestBody: |
|
||||
{
|
||||
"age": {{.age}},
|
||||
"city": "{{.city}}"
|
||||
}
|
||||
bodyParams:
|
||||
- name: age
|
||||
description: age number
|
||||
type: integer
|
||||
- name: city
|
||||
description: city string
|
||||
type: string
|
||||
headers:
|
||||
Authorization: API_KEY
|
||||
Content-Type: application/json
|
||||
headerParams:
|
||||
- name: Language
|
||||
description: language string
|
||||
type: string
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|--------------|:------------------------------------------:|:------------:|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "http". |
|
||||
| source | string | true | Name of the source the HTTP request should be sent to. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
| path | string | true | The path of the HTTP request. You can include static query parameters in the path string. |
|
||||
| method | string | true | The HTTP method to use (e.g., GET, POST, PUT, DELETE). |
|
||||
| headers | map[string]string | false | A map of headers to include in the HTTP request (overrides source headers). |
|
||||
| requestBody | string | false | The request body payload. Use [go template][go-template-doc] with the parameter name as the placeholder (e.g., `{{.id}}` will be replaced with the value of the parameter that has name `id` in the `bodyParams` section). |
|
||||
| queryParams | [parameters](_index#specifying-parameters) | false | List of [parameters](_index#specifying-parameters) that will be inserted into the query string. |
|
||||
| bodyParams | [parameters](_index#specifying-parameters) | false | List of [parameters](_index#specifying-parameters) that will be inserted into the request body payload. |
|
||||
| headerParams | [parameters](_index#specifying-parameters) | false | List of [parameters](_index#specifying-parameters) that will be inserted as the request headers. |
|
||||
|
||||
[go-template-doc]: <https://pkg.go.dev/text/template#pkg-overview>
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
cloudsqlmysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
||||
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
dgraphsrc "github.com/googleapis/genai-toolbox/internal/sources/dgraph"
|
||||
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||
mssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||
mysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||
neo4jrc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
@@ -34,6 +35,7 @@ import (
|
||||
spannersrc "github.com/googleapis/genai-toolbox/internal/sources/spanner"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/dgraph"
|
||||
httptool "github.com/googleapis/genai-toolbox/internal/tools/http"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mssqlsql"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mysqlsql"
|
||||
neo4jtool "github.com/googleapis/genai-toolbox/internal/tools/neo4j"
|
||||
@@ -212,6 +214,12 @@ func (c *SourceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interf
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case httpsrc.SourceKind:
|
||||
actual := httpsrc.DefaultConfig(name)
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of data source", kind)
|
||||
}
|
||||
@@ -329,6 +337,12 @@ func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interfac
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case httptool.ToolKind:
|
||||
actual := httptool.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of tool", kind)
|
||||
}
|
||||
|
||||
91
internal/sources/http/http.go
Normal file
91
internal/sources/http/http.go
Normal file
@@ -0,0 +1,91 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "http"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
BaseURL string `yaml:"baseUrl"`
|
||||
Timeout string `yaml:"timeout"`
|
||||
DefaultHeaders map[string]string `yaml:"headers"`
|
||||
QueryParams map[string]string `yaml:"queryParams"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
// DefaultConfig is a helper function that generates the default configuration for an HTTP Tool Config.
|
||||
func DefaultConfig(name string) Config {
|
||||
return Config{Name: name, Timeout: "30s"}
|
||||
}
|
||||
|
||||
// Initialize initializes an HTTP Source instance.
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
duration, err := time.ParseDuration(r.Timeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse Timeout string as time.Duration: %s", err)
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: duration,
|
||||
}
|
||||
|
||||
// Validate BaseURL
|
||||
_, err = url.ParseRequestURI(r.BaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse BaseUrl %v", err)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
BaseURL: r.BaseURL,
|
||||
DefaultHeaders: r.DefaultHeaders,
|
||||
QueryParams: r.QueryParams,
|
||||
Client: &client,
|
||||
}
|
||||
return s, nil
|
||||
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
BaseURL string `yaml:"baseUrl"`
|
||||
DefaultHeaders map[string]string `yaml:"headers"`
|
||||
QueryParams map[string]string `yaml:"queryParams"`
|
||||
Client *http.Client
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
128
internal/sources/http/http_test.go
Normal file
128
internal/sources/http/http_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlHttp(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-http-instance:
|
||||
kind: http
|
||||
baseUrl: http://test_server/
|
||||
timeout: 10s
|
||||
headers:
|
||||
Authorization: test_header
|
||||
Custom-Header: custom
|
||||
queryParams:
|
||||
api-key: test_api_key
|
||||
param: param-value
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-http-instance": http.Config{
|
||||
Name: "my-http-instance",
|
||||
Kind: http.SourceKind,
|
||||
BaseURL: "http://test_server/",
|
||||
Timeout: "10s",
|
||||
DefaultHeaders: map[string]string{"Authorization": "test_header", "Custom-Header": "custom"},
|
||||
QueryParams: map[string]string{"api-key": "test_api_key", "param": "param-value"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
sources:
|
||||
my-http-instance:
|
||||
kind: http
|
||||
baseUrl: http://test_server/
|
||||
timeout: 10s
|
||||
headers:
|
||||
Authorization: test_header
|
||||
queryParams:
|
||||
api-key: test_api_key
|
||||
project: test-project
|
||||
`,
|
||||
err: "unable to parse as \"http\"",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
sources:
|
||||
my-http-instance:
|
||||
baseUrl: http://test_server/
|
||||
`,
|
||||
err: "missing 'kind' field for \"my-http-instance\"",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
276
internal/tools/http/http.go
Normal file
276
internal/tools/http/http.go
Normal file
@@ -0,0 +1,276 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"maps"
|
||||
"text/template"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const ToolKind string = "http"
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Path string `yaml:"path" validate:"required"`
|
||||
Method tools.HTTPMethod `yaml:"method" validate:"required"`
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
RequestBody string `yaml:"requestBody"`
|
||||
QueryParams tools.Parameters `yaml:"queryParams"`
|
||||
BodyParams tools.Parameters `yaml:"bodyParams"`
|
||||
HeaderParams tools.Parameters `yaml:"headerParams"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return ToolKind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*httpsrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `http`", ToolKind)
|
||||
}
|
||||
|
||||
// Create URL based on BaseURL and Path
|
||||
// Attach query parameters
|
||||
u, err := url.Parse(s.BaseURL + cfg.Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing URL: %s", err)
|
||||
}
|
||||
|
||||
// Get existing query parameters from the URL
|
||||
queryParameters := u.Query()
|
||||
for key, value := range s.QueryParams {
|
||||
queryParameters.Add(key, value)
|
||||
}
|
||||
u.RawQuery = queryParameters.Encode()
|
||||
|
||||
// Combine Source and Tool headers.
|
||||
// In case of conflict, Tool header overrides Source header
|
||||
combinedHeaders := make(map[string]string)
|
||||
maps.Copy(combinedHeaders, s.DefaultHeaders)
|
||||
maps.Copy(combinedHeaders, cfg.Headers)
|
||||
|
||||
// Create a slice for all parameters
|
||||
allParameters := slices.Concat(cfg.BodyParams, cfg.HeaderParams, cfg.QueryParams)
|
||||
|
||||
// Create parameter manifest
|
||||
paramManifest := slices.Concat(
|
||||
cfg.QueryParams.Manifest(),
|
||||
cfg.BodyParams.Manifest(),
|
||||
cfg.HeaderParams.Manifest(),
|
||||
)
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
// Verify there are no duplicate parameter names
|
||||
seenNames := make(map[string]bool)
|
||||
for _, param := range paramManifest {
|
||||
if _, exists := seenNames[param.Name]; exists {
|
||||
return nil, fmt.Errorf("parameter name must be unique across queryParams, bodyParams, and headerParams. Duplicate parameter: %s", param.Name)
|
||||
}
|
||||
seenNames[param.Name] = true
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: ToolKind,
|
||||
URL: u,
|
||||
Method: cfg.Method,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
RequestBody: cfg.RequestBody,
|
||||
QueryParams: cfg.QueryParams,
|
||||
BodyParams: cfg.BodyParams,
|
||||
HeaderParams: cfg.HeaderParams,
|
||||
Headers: combinedHeaders,
|
||||
Client: s.Client,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
|
||||
URL *url.URL `yaml:"url"`
|
||||
Method tools.HTTPMethod `yaml:"method"`
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
RequestBody string `yaml:"requestBody"`
|
||||
QueryParams tools.Parameters `yaml:"queryParams"`
|
||||
BodyParams tools.Parameters `yaml:"bodyParams"`
|
||||
HeaderParams tools.Parameters `yaml:"headerParams"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
}
|
||||
|
||||
// helper function to convert a parameter to JSON formatted string.
|
||||
func convertParamToJSON(param any) (string, error) {
|
||||
jsonData, err := json.Marshal(param)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal param to JSON: %w", err)
|
||||
}
|
||||
return string(jsonData), nil
|
||||
}
|
||||
|
||||
// Helper function to generate the HTTP request body upon Tool invocation.
|
||||
func getRequestBody(bodyParams tools.Parameters, requestBodyPayload string, paramsMap map[string]any) (string, error) {
|
||||
// Create a map for request body parameters
|
||||
bodyParamsMap := make(map[string]any)
|
||||
for _, p := range bodyParams {
|
||||
k := p.GetName()
|
||||
v, ok := paramsMap[k]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing request body parameter %s", k)
|
||||
}
|
||||
bodyParamsMap[k] = v
|
||||
}
|
||||
|
||||
// Create a FuncMap to format array parameters
|
||||
funcMap := template.FuncMap{
|
||||
"json": convertParamToJSON,
|
||||
}
|
||||
templ, err := template.New("body").Funcs(funcMap).Parse(requestBodyPayload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing request body: %s", err)
|
||||
}
|
||||
var result bytes.Buffer
|
||||
err = templ.Execute(&result, bodyParamsMap)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error replacing body payload: %s", err)
|
||||
}
|
||||
return result.String(), nil
|
||||
}
|
||||
|
||||
// Helper function to generate the HTTP request URL upon Tool invocation.
|
||||
func getURL(u *url.URL, queryParams tools.Parameters, paramsMap map[string]any) (string, error) {
|
||||
// Set dynamic query parameters
|
||||
query := u.Query()
|
||||
for _, p := range queryParams {
|
||||
query.Add(p.GetName(), fmt.Sprintf("%v", paramsMap[p.GetName()]))
|
||||
}
|
||||
u.RawQuery = query.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// Helper function to generate the HTTP headers upon Tool invocation.
|
||||
func getHeaders(headerParams tools.Parameters, defaultHeaders map[string]string, paramsMap map[string]any) (map[string]string, error) {
|
||||
// Populate header params
|
||||
allHeaders := make(map[string]string)
|
||||
maps.Copy(allHeaders, defaultHeaders)
|
||||
for _, p := range headerParams {
|
||||
headerValue, ok := paramsMap[p.GetName()]
|
||||
if ok {
|
||||
if strValue, ok := headerValue.(string); ok {
|
||||
allHeaders[p.GetName()] = strValue
|
||||
} else {
|
||||
return nil, fmt.Errorf("header param %s got value of type %t, not string", p.GetName(), headerValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
return allHeaders, nil
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
// Calculate request body
|
||||
requestBody, err := getRequestBody(t.BodyParams, t.RequestBody, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating request body: %s", err)
|
||||
}
|
||||
|
||||
// Calculate URL
|
||||
urlString, err := getURL(t.URL, t.QueryParams, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating query parameters: %s", err)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(string(t.Method), urlString, strings.NewReader(requestBody))
|
||||
|
||||
// Calculate request headers
|
||||
allHeaders, err := getHeaders(t.HeaderParams, t.Headers, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating request headers: %s", err)
|
||||
}
|
||||
// Set request headers
|
||||
for k, v := range allHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
// Make request and fetch response
|
||||
resp, err := t.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making HTTP request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var body []byte
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return []any{string(body)}, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
199
internal/tools/http/http_test.go
Normal file
199
internal/tools/http/http_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
http "github.com/googleapis/genai-toolbox/internal/tools/http"
|
||||
)
|
||||
|
||||
func TestParseFromYamlHTTP(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: http
|
||||
source: my-instance
|
||||
method: GET
|
||||
path: "search?name=alice&pet=cat"
|
||||
description: some description
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
queryParams:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: user_id
|
||||
- name: other-auth-service
|
||||
field: user_id
|
||||
requestBody: |
|
||||
{
|
||||
"age": {{.age}}
|
||||
"city": "{{.city}}"
|
||||
"food": {{.food}}
|
||||
}
|
||||
bodyParams:
|
||||
- name: age
|
||||
type: integer
|
||||
description: age num
|
||||
- name: city
|
||||
type: string
|
||||
description: city string
|
||||
headers:
|
||||
Authorization: API_KEY
|
||||
Content-Type: application/json
|
||||
headerParams:
|
||||
- name: Language
|
||||
type: string
|
||||
description: language string
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": http.Config{
|
||||
Name: "example_tool",
|
||||
Kind: http.ToolKind,
|
||||
Source: "my-instance",
|
||||
Method: "GET",
|
||||
Path: "search?name=alice&pet=cat",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
QueryParams: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("country", "some description",
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
{Name: "other-auth-service", Field: "user_id"}}),
|
||||
},
|
||||
RequestBody: `{
|
||||
"age": {{.age}}
|
||||
"city": "{{.city}}"
|
||||
"food": {{.food}}
|
||||
}
|
||||
`,
|
||||
BodyParams: []tools.Parameter{tools.NewIntParameter("age", "age num"), tools.NewStringParameter("city", "city string")},
|
||||
Headers: map[string]string{"Authorization": "API_KEY", "Content-Type": "application/json"},
|
||||
HeaderParams: []tools.Parameter{tools.NewStringParameter("Language", "language string")},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlHTTP(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid method",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: http
|
||||
source: my-instance
|
||||
method: GOT
|
||||
path: "search?name=alice&pet=cat"
|
||||
description: some description
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
queryParams:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: user_id
|
||||
- name: other-auth-service
|
||||
field: user_id
|
||||
requestBody: |
|
||||
{
|
||||
"age": {{.age}},
|
||||
"city": "{{.city}}",
|
||||
}
|
||||
bodyParams:
|
||||
- name: age
|
||||
type: integer
|
||||
description: age num
|
||||
- name: city
|
||||
type: string
|
||||
description: city string
|
||||
headers:
|
||||
Authorization: API_KEY
|
||||
Content-Type: application/json
|
||||
headerParams:
|
||||
- name: Language
|
||||
type: string
|
||||
description: language string
|
||||
`,
|
||||
err: `GOT is not a valid http method`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
49
internal/tools/http_method.go
Normal file
49
internal/tools/http_method.go
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// HTTPMethod is a string of a valid HTTP method (e.g "GET")
|
||||
type HTTPMethod string
|
||||
|
||||
// isValidHTTPMethod checks if the input string matches one of the method constants defined in the net/http package
|
||||
func isValidHTTPMethod(method string) bool {
|
||||
|
||||
switch method {
|
||||
case http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete,
|
||||
http.MethodPatch, http.MethodHead, http.MethodOptions, http.MethodTrace,
|
||||
http.MethodConnect:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *HTTPMethod) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
var httpMethod string
|
||||
if err := unmarshal(&httpMethod); err != nil {
|
||||
return fmt.Errorf(`error unmarshalling HTTP method: %s`, err)
|
||||
}
|
||||
httpMethod = strings.ToUpper(httpMethod)
|
||||
if !isValidHTTPMethod(httpMethod) {
|
||||
return fmt.Errorf(`%s is not a valid http method`, httpMethod)
|
||||
}
|
||||
*i = HTTPMethod(httpMethod)
|
||||
return nil
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st
|
||||
var v any
|
||||
paramAuthServices := p.GetAuthServices()
|
||||
name := p.GetName()
|
||||
if paramAuthServices == nil {
|
||||
if len(paramAuthServices) == 0 {
|
||||
// parse non auth-required parameter
|
||||
var ok bool
|
||||
v, ok = data[name]
|
||||
@@ -567,8 +567,8 @@ func (p *ArrayParameter) UnmarshalYAML(ctx context.Context, unmarshal func(inter
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse 'items' field: %w", err)
|
||||
}
|
||||
if i.GetAuthServices() != nil {
|
||||
return fmt.Errorf("nested items should not have auth services.")
|
||||
if i.GetAuthServices() != nil && len(i.GetAuthServices()) != 0 {
|
||||
return fmt.Errorf("nested items should not have auth services")
|
||||
}
|
||||
p.Items = i
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ package tests
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// GetToolsConfig returns a mock tools config file
|
||||
@@ -97,6 +99,97 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, param_tool_statement,
|
||||
return toolsFile
|
||||
}
|
||||
|
||||
// GetHTTPToolsConfig returns a mock HTTP tool's config file
|
||||
func GetHTTPToolsConfig(sourceConfig map[string]any, toolKind string) map[string]any {
|
||||
// Write config into a file and pass it to command
|
||||
otherSourceConfig := make(map[string]any)
|
||||
for k, v := range sourceConfig {
|
||||
otherSourceConfig[k] = v
|
||||
}
|
||||
otherSourceConfig["headers"] = map[string]string{"X-Custom-Header": "unexpected", "Content-Type": "application/json"}
|
||||
otherSourceConfig["queryParams"] = map[string]any{"id": 1, "name": "Sid"}
|
||||
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-instance": sourceConfig,
|
||||
"other-instance": otherSourceConfig,
|
||||
},
|
||||
"authServices": map[string]any{
|
||||
"my-google-auth": map[string]any{
|
||||
"kind": "google",
|
||||
"clientId": ClientId,
|
||||
},
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"path": "/tool0",
|
||||
"method": "POST",
|
||||
"source": "my-instance",
|
||||
"requestBody": "{}",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
},
|
||||
"my-param-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"method": "GET",
|
||||
"path": "/tool1",
|
||||
"description": "some description",
|
||||
"queryParams": []tools.Parameter{
|
||||
tools.NewIntParameter("id", "user ID")},
|
||||
"requestBody": `{
|
||||
"age": 36,
|
||||
"name": "{{.name}}"
|
||||
}
|
||||
`,
|
||||
"bodyParams": []tools.Parameter{tools.NewStringParameter("name", "user name")},
|
||||
"headers": map[string]string{"Content-Type": "application/json"},
|
||||
},
|
||||
"my-auth-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"method": "GET",
|
||||
"path": "/tool2",
|
||||
"description": "some description",
|
||||
"requestBody": "{}",
|
||||
"queryParams": []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("email", "some description",
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth", Field: "email"}}),
|
||||
},
|
||||
},
|
||||
"my-auth-required-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"method": "POST",
|
||||
"path": "/tool0",
|
||||
"description": "some description",
|
||||
"requestBody": "{}",
|
||||
"authRequired": []string{"my-google-auth"},
|
||||
},
|
||||
"my-advanced-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "other-instance",
|
||||
"method": "get",
|
||||
"path": "/tool3?id=2",
|
||||
"description": "some description",
|
||||
"headers": map[string]string{
|
||||
"X-Custom-Header": "example",
|
||||
},
|
||||
"queryParams": []tools.Parameter{
|
||||
tools.NewIntParameter("id", "user ID"), tools.NewStringParameter("country", "country")},
|
||||
"requestBody": `{
|
||||
"place": "zoo",
|
||||
"animals": {{json .animalArray }}
|
||||
}
|
||||
`,
|
||||
"bodyParams": []tools.Parameter{tools.NewArrayParameter("animalArray", "animals in the zoo", tools.NewStringParameter("animals", "desc"))},
|
||||
"headerParams": []tools.Parameter{tools.NewStringParameter("X-Other-Header", "custom header")},
|
||||
},
|
||||
},
|
||||
}
|
||||
return toolsFile
|
||||
}
|
||||
|
||||
// GetPostgresSQLParamToolInfo returns statements and param for my-param-tool postgres-sql kind
|
||||
func GetPostgresSQLParamToolInfo(tableName string) (string, string, string, []any) {
|
||||
create_statement := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT);", tableName)
|
||||
|
||||
335
tests/http_test.go
Normal file
335
tests/http_test.go
Normal file
@@ -0,0 +1,335 @@
|
||||
//go:build integration && http
|
||||
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
HTTP_SOURCE_KIND = "http"
|
||||
HTTP_TOOL_KIND = "http"
|
||||
)
|
||||
|
||||
func getHTTPSourceConfig(t *testing.T) map[string]any {
|
||||
idToken, err := GetGoogleIdToken(ClientId)
|
||||
if err != nil {
|
||||
t.Fatalf("error getting ID token: %s", err)
|
||||
}
|
||||
idToken = "Bearer " + idToken
|
||||
return map[string]any{
|
||||
"kind": HTTP_SOURCE_KIND,
|
||||
"headers": map[string]string{"Authorization": idToken},
|
||||
}
|
||||
}
|
||||
|
||||
// handler function for the test server
|
||||
func multiTool(w http.ResponseWriter, r *http.Request) {
|
||||
path := r.URL.Path
|
||||
path = strings.TrimPrefix(path, "/") // Remove leading slash
|
||||
|
||||
switch path {
|
||||
case "tool0":
|
||||
handleTool0(w, r)
|
||||
case "tool1":
|
||||
handleTool1(w, r)
|
||||
case "tool2":
|
||||
handleTool2(w, r)
|
||||
case "tool3":
|
||||
handleTool3(w, r)
|
||||
default:
|
||||
http.NotFound(w, r) // Return 404 for unknown paths
|
||||
}
|
||||
}
|
||||
|
||||
// handler function for the test server
|
||||
func handleTool0(w http.ResponseWriter, r *http.Request) {
|
||||
// expect POST method
|
||||
if r.Method != http.MethodPost {
|
||||
errorMessage := fmt.Sprintf("expected POST method but got: %s", string(r.Method))
|
||||
http.Error(w, errorMessage, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
response := []string{
|
||||
"Hello",
|
||||
"World",
|
||||
}
|
||||
err := json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode JSON", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handler function for the test server
|
||||
func handleTool1(w http.ResponseWriter, r *http.Request) {
|
||||
// expect GET method
|
||||
if r.Method != http.MethodGet {
|
||||
errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method))
|
||||
http.Error(w, errorMessage, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// Parse request body
|
||||
var requestBody map[string]interface{}
|
||||
bodyBytes, readErr := io.ReadAll(r.Body)
|
||||
if readErr != nil {
|
||||
http.Error(w, "Bad Request: Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
err := json.Unmarshal(bodyBytes, &requestBody)
|
||||
if err != nil {
|
||||
errorMessage := fmt.Sprintf("Bad Request: Error unmarshalling request body: %s, Raw body: %s", err, string(bodyBytes))
|
||||
http.Error(w, errorMessage, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract name
|
||||
name, ok := requestBody["name"].(string)
|
||||
if !ok || name == "" {
|
||||
http.Error(w, "Bad Request: Missing or invalid name", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if name == "Alice" {
|
||||
response := `{"id":1,"name":"Alice"},{"id":3,"name":"Sid"}`
|
||||
_, err := w.Write([]byte(response))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// handler function for the test server
|
||||
func handleTool2(w http.ResponseWriter, r *http.Request) {
|
||||
// expect GET method
|
||||
if r.Method != http.MethodGet {
|
||||
errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method))
|
||||
http.Error(w, errorMessage, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
email := r.URL.Query().Get("email")
|
||||
if email != "" {
|
||||
response := `{"name":"Alice"}`
|
||||
_, err := w.Write([]byte(response))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// handler function for the test server
|
||||
func handleTool3(w http.ResponseWriter, r *http.Request) {
|
||||
// expect GET method
|
||||
if r.Method != http.MethodGet {
|
||||
errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method))
|
||||
http.Error(w, errorMessage, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check request headers
|
||||
expectedHeaders := map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"X-Custom-Header": "example",
|
||||
"X-Other-Header": "test",
|
||||
}
|
||||
for header, expectedValue := range expectedHeaders {
|
||||
if r.Header.Get(header) != expectedValue {
|
||||
errorMessage := fmt.Sprintf("Bad Request: Missing or incorrect header: %s", header)
|
||||
http.Error(w, errorMessage, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check query parameters
|
||||
expectedQueryParams := map[string][]string{
|
||||
"id": []string{"2", "1", "3"},
|
||||
"country": []string{"US"},
|
||||
}
|
||||
query := r.URL.Query()
|
||||
for param, expectedValueSlice := range expectedQueryParams {
|
||||
values, ok := query[param]
|
||||
if ok {
|
||||
if !reflect.DeepEqual(expectedValueSlice, values) {
|
||||
errorMessage := fmt.Sprintf("Bad Request: Incorrect query parameter: %s, actual: %s", param, query[param])
|
||||
http.Error(w, errorMessage, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
errorMessage := fmt.Sprintf("Bad Request: Missing query parameter: %s, actual: %s", param, query[param])
|
||||
http.Error(w, errorMessage, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
var requestBody map[string]interface{}
|
||||
bodyBytes, readErr := io.ReadAll(r.Body)
|
||||
if readErr != nil {
|
||||
http.Error(w, "Bad Request: Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
err := json.Unmarshal(bodyBytes, &requestBody)
|
||||
if err != nil {
|
||||
errorMessage := fmt.Sprintf("Bad Request: Error unmarshalling request body: %s, Raw body: %s", err, string(bodyBytes))
|
||||
http.Error(w, errorMessage, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check request body
|
||||
expectedBody := map[string]interface{}{
|
||||
"place": "zoo",
|
||||
"animals": []any{"rabbit", "ostrich", "whale"},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(requestBody, expectedBody) {
|
||||
errorMessage := fmt.Sprintf("Bad Request: Incorrect request body. Expected: %v, Got: %v", expectedBody, requestBody)
|
||||
http.Error(w, errorMessage, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Return a JSON array as the response
|
||||
response := []any{
|
||||
"Hello", "World",
|
||||
}
|
||||
err = json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode JSON", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolEndpoints(t *testing.T) {
|
||||
// start a test server
|
||||
server := httptest.NewServer(http.HandlerFunc(multiTool))
|
||||
defer server.Close()
|
||||
|
||||
sourceConfig := getHTTPSourceConfig(t)
|
||||
sourceConfig["baseUrl"] = server.URL
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
toolsFile := GetHTTPToolsConfig(sourceConfig, HTTP_TOOL_KIND)
|
||||
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
select_1_want := `["[\"Hello\",\"World\"]\n"]`
|
||||
RunToolGetTest(t)
|
||||
RunToolInvokeTest(t, select_1_want)
|
||||
RunAdvancedHTTPInvokeTest(t)
|
||||
}
|
||||
|
||||
// RunToolInvoke runs the tool invoke endpoint
|
||||
func RunAdvancedHTTPInvokeTest(t *testing.T) {
|
||||
// Test HTTP tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-advanced-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-advanced-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"animalArray": ["rabbit", "ostrich", "whale"], "id": 3, "country": "US", "X-Other-Header": "test"}`)),
|
||||
want: `["[\"Hello\",\"World\"]\n"]`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-advanced-tool with wrong params",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-advanced-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"animalArray": ["rabbit", "ostrich", "whale"], "id": 4, "country": "US", "X-Other-Header": "test"}`)),
|
||||
isErr: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr == true {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
@@ -287,12 +288,17 @@ func RunToolInvokeTest(t *testing.T, select_1_want string) {
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
// Remove `\` and `"` for string comparison
|
||||
got = strings.ReplaceAll(got, "\\", "")
|
||||
want := strings.ReplaceAll(tc.want, "\\", "")
|
||||
got = strings.ReplaceAll(got, "\"", "")
|
||||
want = strings.ReplaceAll(want, "\"", "")
|
||||
if got != want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user