refactor: move source implementation in Invoke() function to Source (#2238)

Move source-related queries from `Invoke()` function into Source.

The following sources were updated in this PR:
* mssql
* cloudsql-mssql
* mysql
* cloudsql-mysql

This is an effort to generalizing tools to work with any Source that
implements a specific interface. This will provide a better segregation
of the roles for Tools vs Source.

Tool's role will be limited to the following:
* Resolve any pre-implementation steps or parameters (e.g. template
parameters)
* Retrieving Source
* Calling the source's implementation
This commit is contained in:
Yuan Teoh
2025-12-30 20:23:59 -08:00
committed by GitHub
parent 20447746e1
commit 55eb958c2a
14 changed files with 245 additions and 446 deletions

View File

@@ -25,6 +25,7 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
"go.opentelemetry.io/otel/trace"
)
@@ -107,6 +108,48 @@ func (s *Source) MSSQLDB() *sql.DB {
return s.Db
}
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
results, err := s.MSSQLDB().QueryContext(ctx, statement, params...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer results.Close()
cols, err := results.Columns()
// If Columns() errors, it might be a DDL/DML without an OUTPUT clause.
// We proceed, and results.Err() will catch actual query execution errors.
// 'out' will remain nil if cols is empty or err is not nil here.
var out []any
if err == nil && len(cols) > 0 {
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
for results.Next() {
scanErr := results.Scan(values...)
if scanErr != nil {
return nil, fmt.Errorf("unable to parse row: %w", scanErr)
}
row := orderedmap.Row{}
for i, name := range cols {
row.Add(name, rawValues[i])
}
out = append(out, row)
}
}
// Check for errors from iterating over rows or from the query execution itself.
// results.Close() is handled by defer.
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err)
}
return out, nil
}
func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)

View File

@@ -24,7 +24,9 @@ import (
"cloud.google.com/go/cloudsqlconn/mysql/mysql"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
"go.opentelemetry.io/otel/trace"
)
@@ -100,6 +102,60 @@ func (s *Source) MySQLPool() *sql.DB {
return s.Pool
}
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
results, err := s.MySQLPool().QueryContext(ctx, statement, params...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer results.Close()
cols, err := results.Columns()
if err != nil {
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
}
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
colTypes, err := results.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("unable to get column types: %w", err)
}
var out []any
for results.Next() {
err := results.Scan(values...)
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
row := orderedmap.Row{}
for i, name := range cols {
val := rawValues[i]
if val == nil {
row.Add(name, nil)
continue
}
convertedValue, err := mysqlcommon.ConvertToType(colTypes[i], val)
if err != nil {
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
}
row.Add(name, convertedValue)
}
out = append(out, row)
}
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
}
return out, nil
}
func getConnectionConfig(ctx context.Context, user, pass string) (string, string, bool, error) {
useIAM := true

View File

@@ -23,6 +23,7 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
_ "github.com/microsoft/go-mssqldb"
"go.opentelemetry.io/otel/trace"
)
@@ -104,6 +105,48 @@ func (s *Source) MSSQLDB() *sql.DB {
return s.Db
}
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
results, err := s.MSSQLDB().QueryContext(ctx, statement, params...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer results.Close()
cols, err := results.Columns()
// If Columns() errors, it might be a DDL/DML without an OUTPUT clause.
// We proceed, and results.Err() will catch actual query execution errors.
// 'out' will remain nil if cols is empty or err is not nil here.
var out []any
if err == nil && len(cols) > 0 {
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
for results.Next() {
scanErr := results.Scan(values...)
if scanErr != nil {
return nil, fmt.Errorf("unable to parse row: %w", scanErr)
}
row := orderedmap.Row{}
for i, name := range cols {
row.Add(name, rawValues[i])
}
out = append(out, row)
}
}
// Check for errors from iterating over rows or from the query execution itself.
// results.Close() is handled by defer.
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err)
}
return out, nil
}
func initMssqlConnection(
ctx context.Context,
tracer trace.Tracer,

View File

@@ -24,7 +24,9 @@ import (
_ "github.com/go-sql-driver/mysql"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
"go.opentelemetry.io/otel/trace"
)
@@ -100,6 +102,60 @@ func (s *Source) MySQLPool() *sql.DB {
return s.Pool
}
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
results, err := s.MySQLPool().QueryContext(ctx, statement, params...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer results.Close()
cols, err := results.Columns()
if err != nil {
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
}
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
colTypes, err := results.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("unable to get column types: %w", err)
}
var out []any
for results.Next() {
err := results.Scan(values...)
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
row := orderedmap.Row{}
for i, name := range cols {
val := rawValues[i]
if val == nil {
row.Add(name, nil)
continue
}
convertedValue, err := mysqlcommon.ConvertToType(colTypes[i], val)
if err != nil {
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
}
row.Add(name, convertedValue)
}
out = append(out, row)
}
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
}
return out, nil
}
func initMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string, queryParams map[string]string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)

View File

@@ -23,7 +23,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -45,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
MSSQLDB() *sql.DB
RunSQL(context.Context, string, []any) (any, error)
}
type Config struct {
@@ -106,47 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("error getting logger: %s", err)
}
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
results, err := source.MSSQLDB().QueryContext(ctx, sql)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer results.Close()
cols, err := results.Columns()
// If Columns() errors, it might be a DDL/DML without an OUTPUT clause.
// We proceed, and results.Err() will catch actual query execution errors.
// 'out' will remain nil if cols is empty or err is not nil here.
var out []any
if err == nil && len(cols) > 0 {
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
for results.Next() {
scanErr := results.Scan(values...)
if scanErr != nil {
return nil, fmt.Errorf("unable to parse row: %w", scanErr)
}
row := orderedmap.Row{}
for i, name := range cols {
row.Add(name, rawValues[i])
}
out = append(out, row)
}
}
// Check for errors from iterating over rows or from the query execution itself.
// results.Close() is handled by defer.
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err)
}
return out, nil
return source.RunSQL(ctx, sql, nil)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -292,6 +292,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
MSSQLDB() *sql.DB
RunSQL(context.Context, string, []any) (any, error)
}
type Config struct {
@@ -354,44 +355,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
sql.Named("table_names", paramsMap["table_names"]),
sql.Named("output_format", outputFormat),
}
rows, err := source.MSSQLDB().QueryContext(ctx, listTablesStatement, namedArgs...)
resp, err := source.RunSQL(ctx, listTablesStatement, namedArgs)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
return nil, err
}
defer rows.Close()
cols, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("unable to fetch column names: %w", err)
// if there's no results, return empty list instead of null
resSlice, ok := resp.([]any)
if !ok || len(resSlice) == 0 {
return []any{}, nil
}
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
out := []any{}
for rows.Next() {
err = rows.Scan(values...)
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, name := range cols {
vMap[name] = rawValues[i]
}
out = append(out, vMap)
}
// Check if error occurred during iteration
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
}
return out, nil
return resp, err
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -44,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
MSSQLDB() *sql.DB
RunSQL(context.Context, string, []any) (any, error)
}
type Config struct {
@@ -121,47 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
namedArgs = append(namedArgs, value)
}
}
rows, err := source.MSSQLDB().QueryContext(ctx, newStatement, namedArgs...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
cols, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("unable to fetch column types: %w", err)
}
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
var out []any
for rows.Next() {
err = rows.Scan(values...)
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, name := range cols {
vMap[name] = rawValues[i]
}
out = append(out, vMap)
}
err = rows.Close()
if err != nil {
return nil, fmt.Errorf("unable to close rows: %w", err)
}
// Check if error occurred during iteration
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
return source.RunSQL(ctx, newStatement, namedArgs)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -22,9 +22,7 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -46,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
MySQLPool() *sql.DB
RunSQL(context.Context, string, []any) (any, error)
}
type Config struct {
@@ -107,58 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("error getting logger: %s", err)
}
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
results, err := source.MySQLPool().QueryContext(ctx, sql)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer results.Close()
cols, err := results.Columns()
if err != nil {
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
}
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
colTypes, err := results.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("unable to get column types: %w", err)
}
var out []any
for results.Next() {
err := results.Scan(values...)
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
row := orderedmap.Row{}
for i, name := range cols {
val := rawValues[i]
if val == nil {
row.Add(name, nil)
continue
}
convertedValue, err := mysqlcommon.ConvertToType(colTypes[i], val)
if err != nil {
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
}
row.Add(name, convertedValue)
}
out = append(out, row)
}
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
}
return out, nil
return source.RunSQL(ctx, sql, nil)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -24,6 +24,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -45,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
MySQLPool() *sql.DB
RunSQL(context.Context, string, []any) (any, error)
}
type Config struct {
@@ -108,30 +110,27 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
query := fmt.Sprintf("EXPLAIN FORMAT=JSON %s", sql)
results, err := source.MySQLPool().QueryContext(ctx, query)
result, err := source.RunSQL(ctx, query, nil)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
return nil, err
}
defer results.Close()
var plan string
if results.Next() {
if err := results.Scan(&plan); err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
} else {
// extract and return only the query plan object
resSlice, ok := result.([]any)
if !ok || len(resSlice) == 0 {
return nil, fmt.Errorf("no query plan returned")
}
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
row, ok := resSlice[0].(orderedmap.Row)
if !ok || len(row.Columns) == 0 {
return nil, fmt.Errorf("no query plan returned in row")
}
var out any
plan, ok := row.Columns[0].Value.(string)
if !ok {
return nil, fmt.Errorf("unable to convert plan object to string")
}
var out map[string]any
if err := json.Unmarshal([]byte(plan), &out); err != nil {
return nil, fmt.Errorf("failed to unmarshal query plan json: %w", err)
}
return out, nil
}

View File

@@ -24,7 +24,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
"github.com/googleapis/genai-toolbox/internal/sources/mysql"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -109,6 +108,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
MySQLPool() *sql.DB
RunSQL(context.Context, string, []any) (any, error)
}
type Config struct {
@@ -146,14 +146,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var statement string
sourceKind := rawS.SourceKind()
switch sourceKind {
case mysql.SourceKind:
statement = listActiveQueriesStatementMySQL
case cloudsqlmysql.SourceKind:
statement = listActiveQueriesStatementCloudSQLMySQL
default:
return nil, fmt.Errorf("unsupported source kind kind: %q", sourceKind)
return nil, fmt.Errorf("unsupported source kind: %s", cfg.Source)
}
// finish tool setup
t := Tool{
@@ -200,57 +199,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("error getting logger: %s", err)
}
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, t.statement))
results, err := source.MySQLPool().QueryContext(ctx, t.statement, duration, duration, limit)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer results.Close()
cols, err := results.Columns()
if err != nil {
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
}
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
colTypes, err := results.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("unable to get column types: %w", err)
}
var out []any
for results.Next() {
err := results.Scan(values...)
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, name := range cols {
val := rawValues[i]
if val == nil {
vMap[name] = nil
continue
}
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
if err != nil {
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
}
}
out = append(out, vMap)
}
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
}
return out, nil
return source.RunSQL(ctx, t.statement, []any{duration, duration, limit})
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -22,7 +22,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -67,6 +66,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
MySQLPool() *sql.DB
RunSQL(context.Context, string, []any) (any, error)
}
type Config struct {
@@ -144,57 +144,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("error getting logger: %s", err)
}
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, listTableFragmentationStatement))
results, err := source.MySQLPool().QueryContext(ctx, listTableFragmentationStatement, table_schema, table_schema, table_name, table_name, data_free_threshold_bytes, limit)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer results.Close()
cols, err := results.Columns()
if err != nil {
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
}
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
colTypes, err := results.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("unable to get column types: %w", err)
}
var out []any
for results.Next() {
err := results.Scan(values...)
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, name := range cols {
val := rawValues[i]
if val == nil {
vMap[name] = nil
continue
}
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
if err != nil {
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
}
}
out = append(out, vMap)
}
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
}
return out, nil
sliceParams := []any{table_schema, table_schema, table_name, table_name, data_free_threshold_bytes, limit}
return source.RunSQL(ctx, listTableFragmentationStatement, sliceParams)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -22,7 +22,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -197,6 +196,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
MySQLPool() *sql.DB
RunSQL(context.Context, string, []any) (any, error)
}
type Config struct {
@@ -259,57 +259,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if outputFormat != "simple" && outputFormat != "detailed" {
return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat)
}
results, err := source.MySQLPool().QueryContext(ctx, listTablesStatement, tableNames, outputFormat)
resp, err := source.RunSQL(ctx, listTablesStatement, []any{tableNames, outputFormat})
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
return nil, err
}
cols, err := results.Columns()
if err != nil {
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
// if there's no results, return empty list instead of null
resSlice, ok := resp.([]any)
if !ok || len(resSlice) == 0 {
return []any{}, nil
}
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
defer results.Close()
colTypes, err := results.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("unable to get column types: %w", err)
}
out := []any{}
for results.Next() {
err := results.Scan(values...)
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, name := range cols {
val := rawValues[i]
if val == nil {
vMap[name] = nil
continue
}
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
if err != nil {
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
}
}
out = append(out, vMap)
}
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
}
return out, nil
return resp, err
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -22,7 +22,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -68,6 +67,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
MySQLPool() *sql.DB
RunSQL(context.Context, string, []any) (any, error)
}
type Config struct {
@@ -135,57 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("error getting logger: %s", err)
}
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, listTablesMissingUniqueIndexesStatement))
results, err := source.MySQLPool().QueryContext(ctx, listTablesMissingUniqueIndexesStatement, table_schema, table_schema, limit)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer results.Close()
cols, err := results.Columns()
if err != nil {
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
}
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
colTypes, err := results.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("unable to get column types: %w", err)
}
var out []any
for results.Next() {
err := results.Scan(values...)
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, name := range cols {
val := rawValues[i]
if val == nil {
vMap[name] = nil
continue
}
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
if err != nil {
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
}
}
out = append(out, vMap)
}
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
}
return out, nil
return source.RunSQL(ctx, listTablesMissingUniqueIndexesStatement, []any{table_schema, table_schema, limit})
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -22,7 +22,6 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
@@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
MySQLPool() *sql.DB
RunSQL(context.Context, string, []any) (any, error)
}
type Config struct {
@@ -110,56 +110,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
sliceParams := newParams.AsSlice()
results, err := source.MySQLPool().QueryContext(ctx, newStatement, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
cols, err := results.Columns()
if err != nil {
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
}
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
defer results.Close()
colTypes, err := results.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("unable to get column types: %w", err)
}
var out []any
for results.Next() {
err := results.Scan(values...)
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, name := range cols {
val := rawValues[i]
if val == nil {
vMap[name] = nil
continue
}
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
if err != nil {
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
}
}
out = append(out, vMap)
}
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
}
return out, nil
return source.RunSQL(ctx, newStatement, sliceParams)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {