mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 23:48:04 -05:00
fix: update tool invoke to return json (#266)
Return actual rows as `[]any` that contains `map` of results. Each `map` represent a row, with the key being column name.
This commit is contained in:
@@ -40,3 +40,6 @@ run:
|
|||||||
- cloudsqlmssql
|
- cloudsqlmssql
|
||||||
- cloudsqlmysql
|
- cloudsqlmysql
|
||||||
- neo4j
|
- neo4j
|
||||||
|
- dgraph
|
||||||
|
- mssql
|
||||||
|
- mysql
|
||||||
|
|||||||
@@ -302,7 +302,7 @@ func TestParseToolFile(t *testing.T) {
|
|||||||
Instance: "my-instance",
|
Instance: "my-instance",
|
||||||
IPType: "public",
|
IPType: "public",
|
||||||
Database: "my_db",
|
Database: "my_db",
|
||||||
User: "my_user",
|
User: "my_user",
|
||||||
Password: "my_pass",
|
Password: "my_pass",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -416,7 +416,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
|||||||
Instance: "my-instance",
|
Instance: "my-instance",
|
||||||
IPType: "public",
|
IPType: "public",
|
||||||
Database: "my_db",
|
Database: "my_db",
|
||||||
User: "my_user",
|
User: "my_user",
|
||||||
Password: "my_pass",
|
Password: "my_pass",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -220,7 +220,15 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = render.Render(w, r, &resultResponse{Result: res})
|
resMarshal, err := json.Marshal(res)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to marshal result: %w", err)
|
||||||
|
s.logger.DebugContext(context.Background(), err.Error())
|
||||||
|
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = render.Render(w, r, &resultResponse{Result: string(resMarshal)})
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ render.Renderer = &resultResponse{} // Renderer interface for managing response payloads.
|
var _ render.Renderer = &resultResponse{} // Renderer interface for managing response payloads.
|
||||||
|
|||||||
@@ -39,8 +39,9 @@ type MockTool struct {
|
|||||||
Params []tools.Parameter
|
Params []tools.Parameter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t MockTool) Invoke(tools.ParamValues) (string, error) {
|
func (t MockTool) Invoke(tools.ParamValues) ([]any, error) {
|
||||||
return "", nil
|
mock := make([]any, 0)
|
||||||
|
return mock, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// claims is a map of user info decoded from an auth token
|
// claims is a map of user info decoded from an auth token
|
||||||
|
|||||||
@@ -94,30 +94,29 @@ type Tool struct {
|
|||||||
manifest tools.Manifest
|
manifest tools.Manifest
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||||
paramsMap := params.AsMapWithDollarPrefix()
|
paramsMap := params.AsMapWithDollarPrefix()
|
||||||
|
|
||||||
resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
|
resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := dgraph.CheckError(resp); err != nil {
|
if err := dgraph.CheckError(resp); err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var out []any
|
||||||
var result struct {
|
var result struct {
|
||||||
Data map[string]interface{} `json:"data"`
|
Data map[string]interface{} `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(resp, &result); err != nil {
|
if err := json.Unmarshal(resp, &result); err != nil {
|
||||||
return "", fmt.Errorf("error parsing JSON: %v", err)
|
return nil, fmt.Errorf("error parsing JSON: %v", err)
|
||||||
}
|
}
|
||||||
|
out = append(out, result.Data)
|
||||||
|
|
||||||
return fmt.Sprintf(
|
return out, nil
|
||||||
"Stub tool call for %q! Parameters parsed: %q \n Output: %v",
|
|
||||||
t.Name, paramsMap, result.Data,
|
|
||||||
), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
|
||||||
|
|||||||
@@ -107,9 +107,7 @@ type Tool struct {
|
|||||||
manifest tools.Manifest
|
manifest tools.Manifest
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||||
fmt.Printf("Invoked tool %s\n", t.Name)
|
|
||||||
|
|
||||||
namedArgs := make([]any, 0, len(params))
|
namedArgs := make([]any, 0, len(params))
|
||||||
paramsMap := params.AsReversedMap()
|
paramsMap := params.AsReversedMap()
|
||||||
// To support both named args (e.g @id) and positional args (e.g @p1), check if arg name is contained in the statement.
|
// To support both named args (e.g @id) and positional args (e.g @p1), check if arg name is contained in the statement.
|
||||||
@@ -123,39 +121,44 @@ func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
|||||||
}
|
}
|
||||||
rows, err := t.Db.QueryContext(context.Background(), t.Statement, namedArgs...)
|
rows, err := t.Db.QueryContext(context.Background(), t.Statement, namedArgs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to execute query: %w", err)
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
types, err := rows.ColumnTypes()
|
cols, err := rows.Columns()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to fetch column types: %w", err)
|
return nil, fmt.Errorf("unable to fetch column types: %w", err)
|
||||||
}
|
|
||||||
v := make([]any, len(types))
|
|
||||||
pointers := make([]any, len(types))
|
|
||||||
for i := range types {
|
|
||||||
pointers[i] = &v[i]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetch result into a string
|
// create an array of values for each column, which can be re-used to scan each row
|
||||||
var out strings.Builder
|
rawValues := make([]any, len(cols))
|
||||||
|
values := make([]any, len(cols))
|
||||||
|
for i := range rawValues {
|
||||||
|
values[i] = &rawValues[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
var out []any
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
err = rows.Scan(pointers...)
|
err = rows.Scan(values...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to parse row: %w", err)
|
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||||
}
|
}
|
||||||
out.WriteString(fmt.Sprintf("%s", v))
|
vMap := make(map[string]any)
|
||||||
|
for i, name := range cols {
|
||||||
|
vMap[name] = rawValues[i]
|
||||||
|
}
|
||||||
|
out = append(out, vMap)
|
||||||
}
|
}
|
||||||
err = rows.Close()
|
err = rows.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to close rows: %w", err)
|
return nil, fmt.Errorf("unable to close rows: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if error occured during iteration
|
// Check if error occured during iteration
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
|
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
||||||
@@ -107,44 +106,54 @@ type Tool struct {
|
|||||||
manifest tools.Manifest
|
manifest tools.Manifest
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||||
sliceParams := params.AsSlice()
|
sliceParams := params.AsSlice()
|
||||||
|
|
||||||
results, err := t.Pool.QueryContext(context.Background(), t.Statement, sliceParams...)
|
results, err := t.Pool.QueryContext(context.Background(), t.Statement, sliceParams...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to execute query: %w", err)
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cols, err := results.Columns()
|
cols, err := results.Columns()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to retrieve rows column name: %w", err)
|
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cl := len(cols)
|
// create an array of values for each column, which can be re-used to scan each row
|
||||||
v := make([]any, cl)
|
rawValues := make([]any, len(cols))
|
||||||
pointers := make([]any, cl)
|
values := make([]any, len(cols))
|
||||||
for i := range v {
|
for i := range rawValues {
|
||||||
pointers[i] = &v[i]
|
values[i] = &rawValues[i]
|
||||||
}
|
}
|
||||||
var out strings.Builder
|
|
||||||
|
var out []any
|
||||||
for results.Next() {
|
for results.Next() {
|
||||||
err := results.Scan(pointers...)
|
err := results.Scan(values...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to parse row: %w", err)
|
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||||
}
|
}
|
||||||
out.WriteString(fmt.Sprintf("%s", v))
|
vMap := make(map[string]any)
|
||||||
|
for i, name := range cols {
|
||||||
|
b, ok := rawValues[i].([]byte)
|
||||||
|
if ok {
|
||||||
|
vMap[name] = string(b)
|
||||||
|
} else {
|
||||||
|
vMap[name] = rawValues[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, vMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = results.Close()
|
err = results.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to close rows: %w", err)
|
return nil, fmt.Errorf("unable to close rows: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := results.Err(); err != nil {
|
if err := results.Err(); err != nil {
|
||||||
return "", fmt.Errorf("errors encountered by results.Scan: %w", err)
|
return nil, fmt.Errorf("errors encountered by results.Scan: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ package neo4j
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
|
|
||||||
neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
|
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
|
||||||
@@ -95,28 +94,29 @@ type Tool struct {
|
|||||||
manifest tools.Manifest
|
manifest tools.Manifest
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
|
|
||||||
fmt.Printf("Invoked tool %s\n", t.Name)
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
config := neo4j.ExecuteQueryWithDatabase(t.Database)
|
config := neo4j.ExecuteQueryWithDatabase(t.Database)
|
||||||
results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, t.Driver, t.Statement, paramsMap,
|
results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, t.Driver, t.Statement, paramsMap,
|
||||||
neo4j.EagerResultTransformer, config)
|
neo4j.EagerResultTransformer, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to execute query: %w", err)
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var out strings.Builder
|
var out []any
|
||||||
keys := results.Keys
|
keys := results.Keys
|
||||||
records := results.Records
|
records := results.Records
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
out.WriteString("\n") // fmt.Sprintf("Row: %d\n", row))
|
vMap := make(map[string]any)
|
||||||
for col, value := range record.Values {
|
for col, value := range record.Values {
|
||||||
out.WriteString(fmt.Sprintf("\t%s: %s\n", keys[col], value))
|
vMap[keys[col]] = value
|
||||||
}
|
}
|
||||||
|
out = append(out, vMap)
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, paramsMap, out.String()), nil
|
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ package postgressql
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||||
@@ -109,23 +108,29 @@ type Tool struct {
|
|||||||
manifest tools.Manifest
|
manifest tools.Manifest
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||||
sliceParams := params.AsSlice()
|
sliceParams := params.AsSlice()
|
||||||
results, err := t.Pool.Query(context.Background(), t.Statement, sliceParams...)
|
results, err := t.Pool.Query(context.Background(), t.Statement, sliceParams...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to execute query: %w", err)
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var out strings.Builder
|
fields := results.FieldDescriptions()
|
||||||
|
|
||||||
|
var out []any
|
||||||
for results.Next() {
|
for results.Next() {
|
||||||
v, err := results.Values()
|
v, err := results.Values()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to parse row: %w", err)
|
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||||
}
|
}
|
||||||
out.WriteString(fmt.Sprintf("%s", v))
|
vMap := make(map[string]any)
|
||||||
|
for i, f := range fields {
|
||||||
|
vMap[f.Name] = v[i]
|
||||||
|
}
|
||||||
|
out = append(out, vMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||||
|
|||||||
@@ -121,13 +121,13 @@ func getMapParams(params tools.ParamValues, dialect string) (map[string]interfac
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||||
mapParams, err := getMapParams(params, t.dialect)
|
mapParams, err := getMapParams(params, t.dialect)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("fail to get map params: %w", err)
|
return nil, fmt.Errorf("fail to get map params: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var out strings.Builder
|
var out []any
|
||||||
|
|
||||||
_, err = t.Client.ReadWriteTransaction(context.Background(), func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
|
_, err = t.Client.ReadWriteTransaction(context.Background(), func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
|
||||||
stmt := spanner.Statement{
|
stmt := spanner.Statement{
|
||||||
@@ -145,14 +145,21 @@ func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to parse row: %w", err)
|
return fmt.Errorf("unable to parse row: %w", err)
|
||||||
}
|
}
|
||||||
out.WriteString(row.String())
|
|
||||||
|
vMap := make(map[string]any)
|
||||||
|
cols := row.ColumnNames()
|
||||||
|
for i, c := range cols {
|
||||||
|
vMap[c] = row.ColumnValue(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
out = append(out, vMap)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to execute client: %w", err)
|
return nil, fmt.Errorf("unable to execute client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ type ToolConfig interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Tool interface {
|
type Tool interface {
|
||||||
Invoke(ParamValues) (string, error)
|
Invoke(ParamValues) ([]any, error)
|
||||||
ParseParams(map[string]any, map[string]map[string]any) (ParamValues, error)
|
ParseParams(map[string]any, map[string]map[string]any) (ParamValues, error)
|
||||||
Manifest() Manifest
|
Manifest() Manifest
|
||||||
Authorized([]string) bool
|
Authorized([]string) bool
|
||||||
|
|||||||
@@ -210,7 +210,7 @@ func TestAlloyDBSimpleToolEndpoints(t *testing.T) {
|
|||||||
name: "invoke my-simple-tool",
|
name: "invoke my-simple-tool",
|
||||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||||
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]",
|
want: "[{\"?column?\":1}]",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range invokeTcs {
|
for _, tc := range invokeTcs {
|
||||||
|
|||||||
@@ -62,11 +62,11 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
|
|||||||
var statement string
|
var statement string
|
||||||
switch {
|
switch {
|
||||||
case strings.EqualFold(toolKind, "postgres-sql"):
|
case strings.EqualFold(toolKind, "postgres-sql"):
|
||||||
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = $1;", tableName)
|
statement = fmt.Sprintf("SELECT name FROM %s WHERE email = $1;", tableName)
|
||||||
case strings.EqualFold(toolKind, "mssql-sql"):
|
case strings.EqualFold(toolKind, "mssql-sql"):
|
||||||
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = @email;", tableName)
|
statement = fmt.Sprintf("SELECT name FROM %s WHERE email = @email;", tableName)
|
||||||
case strings.EqualFold(toolKind, "mysql-sql"):
|
case strings.EqualFold(toolKind, "mysql-sql"):
|
||||||
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = ?;", tableName)
|
statement = fmt.Sprintf("SELECT name FROM %s WHERE email = ?;", tableName)
|
||||||
default:
|
default:
|
||||||
t.Fatalf("invalid tool kind: %s", toolKind)
|
t.Fatalf("invalid tool kind: %s", toolKind)
|
||||||
}
|
}
|
||||||
@@ -131,14 +131,7 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
|
|||||||
t.Fatalf("error getting Google ID token: %s", err)
|
t.Fatalf("error getting Google ID token: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tools using database/sql interface only outputs `int64` instead of `int32`
|
wantString := "[{\"name\":\"Alice\"}]"
|
||||||
var wantString string
|
|
||||||
switch toolKind {
|
|
||||||
case "mssql-sql", "mysql-sql":
|
|
||||||
wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int64=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
|
|
||||||
default:
|
|
||||||
wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int32=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test tool invocation with authenticated parameters
|
// Test tool invocation with authenticated parameters
|
||||||
invokeTcs := []struct {
|
invokeTcs := []struct {
|
||||||
@@ -215,13 +208,14 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RunAuthRequiredToolInvocationTest(t *testing.T, sourceConfig map[string]any, toolKind string) {
|
func RunAuthRequiredToolInvocationTest(t *testing.T, sourceConfig map[string]any, toolKind string) {
|
||||||
// Tools using database/sql interface only outputs `int64` instead of `int32`
|
|
||||||
var wantString string
|
var wantString string
|
||||||
switch toolKind {
|
switch toolKind {
|
||||||
case "mssql-sql", "mysql-sql":
|
case "mysql-sql":
|
||||||
wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]"
|
wantString = "[{\"1\":1}]"
|
||||||
|
case "mssql-sql":
|
||||||
|
wantString = "[{\"\":1}]"
|
||||||
default:
|
default:
|
||||||
wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]"
|
wantString = "[{\"?column?\":1}]"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write config into a file and pass it to command
|
// Write config into a file and pass it to command
|
||||||
|
|||||||
@@ -199,7 +199,7 @@ func TestCloudSQLMssql(t *testing.T) {
|
|||||||
name: "invoke my-simple-tool",
|
name: "invoke my-simple-tool",
|
||||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||||
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]",
|
want: "[{\"\":1}]",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range invokeTcs {
|
for _, tc := range invokeTcs {
|
||||||
|
|||||||
@@ -193,7 +193,7 @@ func TestCloudSQLMySQL(t *testing.T) {
|
|||||||
name: "invoke my-simple-tool",
|
name: "invoke my-simple-tool",
|
||||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||||
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]",
|
want: "[{\"1\":1}]",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range invokeTcs {
|
for _, tc := range invokeTcs {
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func TestCloudSQLPostgres(t *testing.T) {
|
|||||||
name: "invoke my-simple-tool",
|
name: "invoke my-simple-tool",
|
||||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||||
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]",
|
want: "[{\"?column?\":1}]",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range invokeTcs {
|
for _, tc := range invokeTcs {
|
||||||
|
|||||||
@@ -222,14 +222,7 @@ func RunToolInvocationWithParamsTest(t *testing.T, sourceConfig map[string]any,
|
|||||||
t.Fatalf("invalid tool kind: %s", toolKind)
|
t.Fatalf("invalid tool kind: %s", toolKind)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tools using database/sql interface only outputs `int64` instead of `int32`
|
wantString := "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]"
|
||||||
var wantString string
|
|
||||||
switch toolKind {
|
|
||||||
case "mssql-sql":
|
|
||||||
wantString = "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int64=1) Alice][%!s(int64=3) Sid]"
|
|
||||||
default:
|
|
||||||
wantString = "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int32=1) Alice][%!s(int32=3) Sid]"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write config into a file and pass it to command
|
// Write config into a file and pass it to command
|
||||||
toolsFile := map[string]any{
|
toolsFile := map[string]any{
|
||||||
|
|||||||
@@ -30,11 +30,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
DGRAPH_URL = os.Getenv("DGRAPH_URL")
|
DGRAPH_URL = os.Getenv("DGRAPH_URL")
|
||||||
)
|
)
|
||||||
|
|
||||||
func requireDgraphVars(t *testing.T) {
|
func requireDgraphVars(t *testing.T) {
|
||||||
if DGRAPH_URL =="" {
|
if DGRAPH_URL == "" {
|
||||||
t.Fatal("'DGRAPH_URL' not set")
|
t.Fatal("'DGRAPH_URL' not set")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -136,8 +136,7 @@ func TestDgraph(t *testing.T) {
|
|||||||
name: "invoke my-simple-dql-tool",
|
name: "invoke my-simple-dql-tool",
|
||||||
api: "http://127.0.0.1:5000/api/tool/my-simple-dql-tool/invoke",
|
api: "http://127.0.0.1:5000/api/tool/my-simple-dql-tool/invoke",
|
||||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||||
want: "Stub tool call for \"my-simple-dql-tool\"! Parameters parsed: map[]" +
|
want: "[{\"result\":[{\"constant\":1}]}]",
|
||||||
" \n Output: map[result:[map[constant:1]]]",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range invokeTcs {
|
for _, tc := range invokeTcs {
|
||||||
@@ -162,6 +161,7 @@ func TestDgraph(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("unable to find result in response body")
|
t.Fatalf("unable to find result in response body")
|
||||||
}
|
}
|
||||||
|
|
||||||
if got != tc.want {
|
if got != tc.want {
|
||||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
MSSQL_DATABASE = os.Getenv("MSSQL_DATABASE")
|
MSSQL_DATABASE = os.Getenv("MSSQL_DATABASE")
|
||||||
MSSQL_HOST = os.Getenv("MSSQL_HOST")
|
MSSQL_HOST = os.Getenv("MSSQL_HOST")
|
||||||
MSSQL_PORT = os.Getenv("MSSQL_PORT")
|
MSSQL_PORT = os.Getenv("MSSQL_PORT")
|
||||||
MSSQL_USER = os.Getenv("MSSQL_USER")
|
MSSQL_USER = os.Getenv("MSSQL_USER")
|
||||||
@@ -149,7 +149,7 @@ func TestMsSQL(t *testing.T) {
|
|||||||
name: "invoke my-simple-tool",
|
name: "invoke my-simple-tool",
|
||||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||||
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]",
|
want: "[{\"\":1}]",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range invokeTcs {
|
for _, tc := range invokeTcs {
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ func TestMySQL(t *testing.T) {
|
|||||||
name: "invoke my-simple-tool",
|
name: "invoke my-simple-tool",
|
||||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||||
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]",
|
want: "[{\"1\":1}]",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range invokeTcs {
|
for _, tc := range invokeTcs {
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ func TestNeo4j(t *testing.T) {
|
|||||||
name: "invoke my-simple-cypher-tool",
|
name: "invoke my-simple-cypher-tool",
|
||||||
api: "http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/invoke",
|
api: "http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/invoke",
|
||||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||||
want: "Stub tool call for \"my-simple-cypher-tool\"! Parameters parsed: map[] \n Output: \n\ta: %!s(int64=1)\n",
|
want: "[{\"a\":1}]",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range invokeTcs {
|
for _, tc := range invokeTcs {
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ func TestPostgres(t *testing.T) {
|
|||||||
name: "invoke my-simple-tool",
|
name: "invoke my-simple-tool",
|
||||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||||
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]",
|
want: "[{\"?column?\":1}]",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range invokeTcs {
|
for _, tc := range invokeTcs {
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ func TestSpanner(t *testing.T) {
|
|||||||
name: "invoke my-simple-tool",
|
name: "invoke my-simple-tool",
|
||||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||||
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: {fields: [type:{code:INT64}], values: [string_value:\"1\"]}",
|
want: "[{\"\":\"1\"}]",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range invokeTcs {
|
for _, tc := range invokeTcs {
|
||||||
|
|||||||
Reference in New Issue
Block a user