fix: update tool invoke to return json (#266)

Return actual rows as `[]any` that contains `map` of results. Each `map`
represent a row, with the key being column name.
This commit is contained in:
Yuan
2025-02-05 13:45:01 -08:00
committed by GitHub
parent 1702f74e99
commit ad58cd5855
23 changed files with 129 additions and 107 deletions

View File

@@ -40,3 +40,6 @@ run:
- cloudsqlmssql
- cloudsqlmysql
- neo4j
- dgraph
- mssql
- mysql

View File

@@ -302,7 +302,7 @@ func TestParseToolFile(t *testing.T) {
Instance: "my-instance",
IPType: "public",
Database: "my_db",
User: "my_user",
User: "my_user",
Password: "my_pass",
},
},
@@ -416,7 +416,7 @@ func TestParseToolFileWithAuth(t *testing.T) {
Instance: "my-instance",
IPType: "public",
Database: "my_db",
User: "my_user",
User: "my_user",
Password: "my_pass",
},
},

View File

@@ -220,7 +220,15 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
return
}
_ = render.Render(w, r, &resultResponse{Result: res})
resMarshal, err := json.Marshal(res)
if err != nil {
err = fmt.Errorf("unable to marshal result: %w", err)
s.logger.DebugContext(context.Background(), err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
return
}
_ = render.Render(w, r, &resultResponse{Result: string(resMarshal)})
}
var _ render.Renderer = &resultResponse{} // Renderer interface for managing response payloads.

View File

@@ -39,8 +39,9 @@ type MockTool struct {
Params []tools.Parameter
}
func (t MockTool) Invoke(tools.ParamValues) (string, error) {
return "", nil
func (t MockTool) Invoke(tools.ParamValues) ([]any, error) {
mock := make([]any, 0)
return mock, nil
}
// claims is a map of user info decoded from an auth token

View File

@@ -94,30 +94,29 @@ type Tool struct {
manifest tools.Manifest
}
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
paramsMap := params.AsMapWithDollarPrefix()
resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
if err != nil {
return "", err
return nil, err
}
if err := dgraph.CheckError(resp); err != nil {
return "", err
return nil, err
}
var out []any
var result struct {
Data map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(resp, &result); err != nil {
return "", fmt.Errorf("error parsing JSON: %v", err)
return nil, fmt.Errorf("error parsing JSON: %v", err)
}
out = append(out, result.Data)
return fmt.Sprintf(
"Stub tool call for %q! Parameters parsed: %q \n Output: %v",
t.Name, paramsMap, result.Data,
), nil
return out, nil
}
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {

View File

@@ -107,9 +107,7 @@ type Tool struct {
manifest tools.Manifest
}
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
fmt.Printf("Invoked tool %s\n", t.Name)
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
namedArgs := make([]any, 0, len(params))
paramsMap := params.AsReversedMap()
// To support both named args (e.g @id) and positional args (e.g @p1), check if arg name is contained in the statement.
@@ -123,39 +121,44 @@ func (t Tool) Invoke(params tools.ParamValues) (string, error) {
}
rows, err := t.Db.QueryContext(context.Background(), t.Statement, namedArgs...)
if err != nil {
return "", fmt.Errorf("unable to execute query: %w", err)
return nil, fmt.Errorf("unable to execute query: %w", err)
}
types, err := rows.ColumnTypes()
cols, err := rows.Columns()
if err != nil {
return "", fmt.Errorf("unable to fetch column types: %w", err)
}
v := make([]any, len(types))
pointers := make([]any, len(types))
for i := range types {
pointers[i] = &v[i]
return nil, fmt.Errorf("unable to fetch column types: %w", err)
}
// fetch result into a string
var out strings.Builder
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
var out []any
for rows.Next() {
err = rows.Scan(pointers...)
err = rows.Scan(values...)
if err != nil {
return "", fmt.Errorf("unable to parse row: %w", err)
return nil, fmt.Errorf("unable to parse row: %w", err)
}
out.WriteString(fmt.Sprintf("%s", v))
vMap := make(map[string]any)
for i, name := range cols {
vMap[name] = rawValues[i]
}
out = append(out, vMap)
}
err = rows.Close()
if err != nil {
return "", fmt.Errorf("unable to close rows: %w", err)
return nil, fmt.Errorf("unable to close rows: %w", err)
}
// Check if error occured during iteration
if err := rows.Err(); err != nil {
return "", err
return nil, err
}
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
return out, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {

View File

@@ -18,7 +18,6 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
@@ -107,44 +106,54 @@ type Tool struct {
manifest tools.Manifest
}
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
sliceParams := params.AsSlice()
results, err := t.Pool.QueryContext(context.Background(), t.Statement, sliceParams...)
if err != nil {
return "", fmt.Errorf("unable to execute query: %w", err)
return nil, fmt.Errorf("unable to execute query: %w", err)
}
cols, err := results.Columns()
if err != nil {
return "", fmt.Errorf("unable to retrieve rows column name: %w", err)
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
}
cl := len(cols)
v := make([]any, cl)
pointers := make([]any, cl)
for i := range v {
pointers[i] = &v[i]
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
var out strings.Builder
var out []any
for results.Next() {
err := results.Scan(pointers...)
err := results.Scan(values...)
if err != nil {
return "", fmt.Errorf("unable to parse row: %w", err)
return nil, fmt.Errorf("unable to parse row: %w", err)
}
out.WriteString(fmt.Sprintf("%s", v))
vMap := make(map[string]any)
for i, name := range cols {
b, ok := rawValues[i].([]byte)
if ok {
vMap[name] = string(b)
} else {
vMap[name] = rawValues[i]
}
}
out = append(out, vMap)
}
err = results.Close()
if err != nil {
return "", fmt.Errorf("unable to close rows: %w", err)
return nil, fmt.Errorf("unable to close rows: %w", err)
}
if err := results.Err(); err != nil {
return "", fmt.Errorf("errors encountered by results.Scan: %w", err)
return nil, fmt.Errorf("errors encountered by results.Scan: %w", err)
}
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
return out, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {

View File

@@ -17,7 +17,6 @@ package neo4j
import (
"context"
"fmt"
"strings"
neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
@@ -95,28 +94,29 @@ type Tool struct {
manifest tools.Manifest
}
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
paramsMap := params.AsMap()
fmt.Printf("Invoked tool %s\n", t.Name)
ctx := context.Background()
config := neo4j.ExecuteQueryWithDatabase(t.Database)
results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, t.Driver, t.Statement, paramsMap,
neo4j.EagerResultTransformer, config)
if err != nil {
return "", fmt.Errorf("unable to execute query: %w", err)
return nil, fmt.Errorf("unable to execute query: %w", err)
}
var out strings.Builder
var out []any
keys := results.Keys
records := results.Records
for _, record := range records {
out.WriteString("\n") // fmt.Sprintf("Row: %d\n", row))
vMap := make(map[string]any)
for col, value := range record.Values {
out.WriteString(fmt.Sprintf("\t%s: %s\n", keys[col], value))
vMap[keys[col]] = value
}
out = append(out, vMap)
}
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, paramsMap, out.String()), nil
return out, nil
}
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {

View File

@@ -17,7 +17,6 @@ package postgressql
import (
"context"
"fmt"
"strings"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
@@ -109,23 +108,29 @@ type Tool struct {
manifest tools.Manifest
}
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
sliceParams := params.AsSlice()
results, err := t.Pool.Query(context.Background(), t.Statement, sliceParams...)
if err != nil {
return "", fmt.Errorf("unable to execute query: %w", err)
return nil, fmt.Errorf("unable to execute query: %w", err)
}
var out strings.Builder
fields := results.FieldDescriptions()
var out []any
for results.Next() {
v, err := results.Values()
if err != nil {
return "", fmt.Errorf("unable to parse row: %w", err)
return nil, fmt.Errorf("unable to parse row: %w", err)
}
out.WriteString(fmt.Sprintf("%s", v))
vMap := make(map[string]any)
for i, f := range fields {
vMap[f.Name] = v[i]
}
out = append(out, vMap)
}
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
return out, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {

View File

@@ -121,13 +121,13 @@ func getMapParams(params tools.ParamValues, dialect string) (map[string]interfac
}
}
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
mapParams, err := getMapParams(params, t.dialect)
if err != nil {
return "", fmt.Errorf("fail to get map params: %w", err)
return nil, fmt.Errorf("fail to get map params: %w", err)
}
var out strings.Builder
var out []any
_, err = t.Client.ReadWriteTransaction(context.Background(), func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
stmt := spanner.Statement{
@@ -145,14 +145,21 @@ func (t Tool) Invoke(params tools.ParamValues) (string, error) {
if err != nil {
return fmt.Errorf("unable to parse row: %w", err)
}
out.WriteString(row.String())
vMap := make(map[string]any)
cols := row.ColumnNames()
for i, c := range cols {
vMap[c] = row.ColumnValue(i)
}
out = append(out, vMap)
}
})
if err != nil {
return "", fmt.Errorf("unable to execute client: %w", err)
return nil, fmt.Errorf("unable to execute client: %w", err)
}
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
return out, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {

View File

@@ -26,7 +26,7 @@ type ToolConfig interface {
}
type Tool interface {
Invoke(ParamValues) (string, error)
Invoke(ParamValues) ([]any, error)
ParseParams(map[string]any, map[string]map[string]any) (ParamValues, error)
Manifest() Manifest
Authorized([]string) bool

View File

@@ -210,7 +210,7 @@ func TestAlloyDBSimpleToolEndpoints(t *testing.T) {
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]",
want: "[{\"?column?\":1}]",
},
}
for _, tc := range invokeTcs {

View File

@@ -62,11 +62,11 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
var statement string
switch {
case strings.EqualFold(toolKind, "postgres-sql"):
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = $1;", tableName)
statement = fmt.Sprintf("SELECT name FROM %s WHERE email = $1;", tableName)
case strings.EqualFold(toolKind, "mssql-sql"):
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = @email;", tableName)
statement = fmt.Sprintf("SELECT name FROM %s WHERE email = @email;", tableName)
case strings.EqualFold(toolKind, "mysql-sql"):
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = ?;", tableName)
statement = fmt.Sprintf("SELECT name FROM %s WHERE email = ?;", tableName)
default:
t.Fatalf("invalid tool kind: %s", toolKind)
}
@@ -131,14 +131,7 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
t.Fatalf("error getting Google ID token: %s", err)
}
// Tools using database/sql interface only outputs `int64` instead of `int32`
var wantString string
switch toolKind {
case "mssql-sql", "mysql-sql":
wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int64=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
default:
wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int32=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
}
wantString := "[{\"name\":\"Alice\"}]"
// Test tool invocation with authenticated parameters
invokeTcs := []struct {
@@ -215,13 +208,14 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
}
func RunAuthRequiredToolInvocationTest(t *testing.T, sourceConfig map[string]any, toolKind string) {
// Tools using database/sql interface only outputs `int64` instead of `int32`
var wantString string
switch toolKind {
case "mssql-sql", "mysql-sql":
wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]"
case "mysql-sql":
wantString = "[{\"1\":1}]"
case "mssql-sql":
wantString = "[{\"\":1}]"
default:
wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]"
wantString = "[{\"?column?\":1}]"
}
// Write config into a file and pass it to command

View File

@@ -199,7 +199,7 @@ func TestCloudSQLMssql(t *testing.T) {
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]",
want: "[{\"\":1}]",
},
}
for _, tc := range invokeTcs {

View File

@@ -193,7 +193,7 @@ func TestCloudSQLMySQL(t *testing.T) {
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]",
want: "[{\"1\":1}]",
},
}
for _, tc := range invokeTcs {

View File

@@ -197,7 +197,7 @@ func TestCloudSQLPostgres(t *testing.T) {
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]",
want: "[{\"?column?\":1}]",
},
}
for _, tc := range invokeTcs {

View File

@@ -222,14 +222,7 @@ func RunToolInvocationWithParamsTest(t *testing.T, sourceConfig map[string]any,
t.Fatalf("invalid tool kind: %s", toolKind)
}
// Tools using database/sql interface only outputs `int64` instead of `int32`
var wantString string
switch toolKind {
case "mssql-sql":
wantString = "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int64=1) Alice][%!s(int64=3) Sid]"
default:
wantString = "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int32=1) Alice][%!s(int32=3) Sid]"
}
wantString := "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]"
// Write config into a file and pass it to command
toolsFile := map[string]any{

View File

@@ -30,11 +30,11 @@ import (
)
var (
DGRAPH_URL = os.Getenv("DGRAPH_URL")
DGRAPH_URL = os.Getenv("DGRAPH_URL")
)
func requireDgraphVars(t *testing.T) {
if DGRAPH_URL =="" {
if DGRAPH_URL == "" {
t.Fatal("'DGRAPH_URL' not set")
}
}
@@ -136,8 +136,7 @@ func TestDgraph(t *testing.T) {
name: "invoke my-simple-dql-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-dql-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "Stub tool call for \"my-simple-dql-tool\"! Parameters parsed: map[]" +
" \n Output: map[result:[map[constant:1]]]",
want: "[{\"result\":[{\"constant\":1}]}]",
},
}
for _, tc := range invokeTcs {
@@ -162,6 +161,7 @@ func TestDgraph(t *testing.T) {
if !ok {
t.Fatalf("unable to find result in response body")
}
if got != tc.want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}

View File

@@ -30,7 +30,7 @@ import (
)
var (
MSSQL_DATABASE = os.Getenv("MSSQL_DATABASE")
MSSQL_DATABASE = os.Getenv("MSSQL_DATABASE")
MSSQL_HOST = os.Getenv("MSSQL_HOST")
MSSQL_PORT = os.Getenv("MSSQL_PORT")
MSSQL_USER = os.Getenv("MSSQL_USER")
@@ -149,7 +149,7 @@ func TestMsSQL(t *testing.T) {
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]",
want: "[{\"\":1}]",
},
}
for _, tc := range invokeTcs {

View File

@@ -149,7 +149,7 @@ func TestMySQL(t *testing.T) {
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]",
want: "[{\"1\":1}]",
},
}
for _, tc := range invokeTcs {

View File

@@ -145,7 +145,7 @@ func TestNeo4j(t *testing.T) {
name: "invoke my-simple-cypher-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "Stub tool call for \"my-simple-cypher-tool\"! Parameters parsed: map[] \n Output: \n\ta: %!s(int64=1)\n",
want: "[{\"a\":1}]",
},
}
for _, tc := range invokeTcs {

View File

@@ -149,7 +149,7 @@ func TestPostgres(t *testing.T) {
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]",
want: "[{\"?column?\":1}]",
},
}
for _, tc := range invokeTcs {

View File

@@ -141,7 +141,7 @@ func TestSpanner(t *testing.T) {
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: {fields: [type:{code:INT64}], values: [string_value:\"1\"]}",
want: "[{\"\":\"1\"}]",
},
}
for _, tc := range invokeTcs {