fix(bigquery,mssql): fix panic on tools with array param (#722)

Fix: https://github.com/googleapis/genai-toolbox/issues/701

Things done:
1. Replace the `AsReversedMap()` helper with `AsMap()`
2. BigQuery's QueryParameter only accepts typed slices as input, but our
arrays are passed in as []any. Therefore, add a logic to convert []any
to a typed array based on the item type.

Tested on MCP inspector:
<img width="409" alt="Screenshot 2025-06-16 at 5 15 55 PM"
src="https://github.com/user-attachments/assets/8053cad5-270e-4d82-b97c-856238c42154"
/>

---------

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
This commit is contained in:
Wenxin Du
2025-06-25 22:54:26 -04:00
committed by GitHub
parent 184c681797
commit 7a6644cf0c
3 changed files with 74 additions and 28 deletions

View File

@@ -125,29 +125,38 @@ type Tool struct {
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
namedArgs := make([]bigqueryapi.QueryParameter, 0, len(params))
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract template params %w", err)
}
newParams, err := tools.GetParams(t.Parameters, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract standard params %w", err)
}
for _, p := range t.Parameters {
name := p.GetName()
value := paramsMap[name]
namedArgs := make([]bigqueryapi.QueryParameter, 0, len(newParams))
newParamsMap := newParams.AsReversedMap()
for _, v := range newParams.AsSlice() {
paramName := newParamsMap[v]
if strings.Contains(newStatement, "@"+paramName) {
// BigQuery's QueryParameter only accepts typed slices as input
// This checks if the param is an array.
// If yes, convert []any to typed slice (e.g []string, []int)
switch arrayParam := value.(type) {
case []any:
var err error
itemType := p.McpManifest().Items.Type
value, err = convertAnySliceToTyped(arrayParam, itemType, name)
if err != nil {
return nil, fmt.Errorf("unable to convert []any to typed slice: %w", err)
}
}
if strings.Contains(t.Statement, "@"+name) {
namedArgs = append(namedArgs, bigqueryapi.QueryParameter{
Name: paramName,
Value: v,
Name: name,
Value: value,
})
} else {
namedArgs = append(namedArgs, bigqueryapi.QueryParameter{
Value: v,
Value: value,
})
}
}
@@ -196,3 +205,47 @@ func (t Tool) McpManifest() tools.McpManifest {
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func convertAnySliceToTyped(s []any, itemType, paramName string) (any, error) {
var typedSlice any
switch itemType {
case "string":
typedSlice := make([]string, len(s))
for j, item := range s {
if s, ok := item.(string); ok {
typedSlice[j] = s
} else {
return nil, fmt.Errorf("parameter '%s': expected item at index %d to be string, got %T", paramName, j, item)
}
}
case "integer":
typedSlice := make([]int64, len(s))
for j, item := range s {
i, ok := item.(int)
if !ok {
return nil, fmt.Errorf("parameter '%s': expected item at index %d to be integer, got %T", paramName, j, item)
}
typedSlice[j] = int64(i)
}
case "float":
typedSlice := make([]float64, len(s))
for j, item := range s {
if f, ok := item.(float64); ok {
typedSlice[j] = f
} else {
return nil, fmt.Errorf("parameter '%s': expected item at index %d to be float, got %T", paramName, j, item)
}
}
case "boolean":
typedSlice := make([]bool, len(s))
for j, item := range s {
if b, ok := item.(bool); ok {
typedSlice[j] = b
} else {
return nil, fmt.Errorf("parameter '%s': expected item at index %d to be boolean, got %T", paramName, j, item)
}
}
}
return typedSlice, nil
}

View File

@@ -138,16 +138,18 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
}
namedArgs := make([]any, 0, len(newParams))
newParamsMap := newParams.AsReversedMap()
// To support both named args (e.g @id) and positional args (e.g @p1), check if arg name is contained in the statement.
for _, v := range newParams.AsSlice() {
paramName := newParamsMap[v]
if strings.Contains(newStatement, "@"+paramName) {
namedArgs = append(namedArgs, sql.Named(paramName, v))
// To support both named args (e.g @id) and positional args (e.g @p1), check
// if arg name is contained in the statement.
for _, p := range t.Parameters {
name := p.GetName()
value := paramsMap[name]
if strings.Contains(newStatement, "@"+name) {
namedArgs = append(namedArgs, sql.Named(name, value))
} else {
namedArgs = append(namedArgs, v)
namedArgs = append(namedArgs, value)
}
}
rows, err := t.Db.QueryContext(ctx, newStatement, namedArgs...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)

View File

@@ -62,15 +62,6 @@ func (p ParamValues) AsMap() map[string]interface{} {
return params
}
// AsReversedMap returns a map of ParamValue's values to names.
func (p ParamValues) AsReversedMap() map[any]string {
params := make(map[any]string)
for _, p := range p {
params[p.Value] = p.Name
}
return params
}
// AsMapByOrderedKeys returns a map of a key's position to it's value, as necessary for Spanner PSQL.
// Example { $1 -> "value1", $2 -> "value2" }
func (p ParamValues) AsMapByOrderedKeys() map[string]interface{} {