diff --git a/internal/sources/cloudsqlmssql/cloud_sql_mssql.go b/internal/sources/cloudsqlmssql/cloud_sql_mssql.go index 1435165fde..02480df326 100644 --- a/internal/sources/cloudsqlmssql/cloud_sql_mssql.go +++ b/internal/sources/cloudsqlmssql/cloud_sql_mssql.go @@ -25,6 +25,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "go.opentelemetry.io/otel/trace" ) @@ -107,6 +108,48 @@ func (s *Source) MSSQLDB() *sql.DB { return s.Db } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.MSSQLDB().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + cols, err := results.Columns() + // If Columns() errors, it might be a DDL/DML without an OUTPUT clause. + // We proceed, and results.Err() will catch actual query execution errors. + // 'out' will remain nil if cols is empty or err is not nil here. + var out []any + if err == nil && len(cols) > 0 { + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + + for results.Next() { + scanErr := results.Scan(values...) + if scanErr != nil { + return nil, fmt.Errorf("unable to parse row: %w", scanErr) + } + row := orderedmap.Row{} + for i, name := range cols { + row.Add(name, rawValues[i]) + } + out = append(out, row) + } + } + + // Check for errors from iterating over rows or from the query execution itself. + // results.Close() is handled by defer. + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err) + } + + return out, nil +} + func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go index 797985454b..759f00af7d 100644 --- a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go +++ b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go @@ -24,7 +24,9 @@ import ( "cloud.google.com/go/cloudsqlconn/mysql/mysql" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "go.opentelemetry.io/otel/trace" ) @@ -100,6 +102,60 @@ func (s *Source) MySQLPool() *sql.DB { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.MySQLPool().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + cols, err := results.Columns() + if err != nil { + return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) + } + + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + + colTypes, err := results.ColumnTypes() + if err != nil { + return nil, fmt.Errorf("unable to get column types: %w", err) + } + + var out []any + for results.Next() { + err := results.Scan(values...) + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + row := orderedmap.Row{} + for i, name := range cols { + val := rawValues[i] + if val == nil { + row.Add(name, nil) + continue + } + + convertedValue, err := mysqlcommon.ConvertToType(colTypes[i], val) + if err != nil { + return nil, fmt.Errorf("errors encountered when converting values: %w", err) + } + row.Add(name, convertedValue) + } + out = append(out, row) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + } + + return out, nil +} + func getConnectionConfig(ctx context.Context, user, pass string) (string, string, bool, error) { useIAM := true diff --git a/internal/sources/mssql/mssql.go b/internal/sources/mssql/mssql.go index 39a37bf5a0..688ccf18c4 100644 --- a/internal/sources/mssql/mssql.go +++ b/internal/sources/mssql/mssql.go @@ -23,6 +23,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" _ "github.com/microsoft/go-mssqldb" "go.opentelemetry.io/otel/trace" ) @@ -104,6 +105,48 @@ func (s *Source) MSSQLDB() *sql.DB { return s.Db } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.MSSQLDB().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + cols, err := results.Columns() + // If Columns() errors, it might be a DDL/DML without an OUTPUT clause. + // We proceed, and results.Err() will catch actual query execution errors. + // 'out' will remain nil if cols is empty or err is not nil here. + var out []any + if err == nil && len(cols) > 0 { + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + + for results.Next() { + scanErr := results.Scan(values...) + if scanErr != nil { + return nil, fmt.Errorf("unable to parse row: %w", scanErr) + } + row := orderedmap.Row{} + for i, name := range cols { + row.Add(name, rawValues[i]) + } + out = append(out, row) + } + } + + // Check for errors from iterating over rows or from the query execution itself. + // results.Close() is handled by defer. + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err) + } + + return out, nil +} + func initMssqlConnection( ctx context.Context, tracer trace.Tracer, diff --git a/internal/sources/mysql/mysql.go b/internal/sources/mysql/mysql.go index 13f4b2a3d9..b456ec9a3f 100644 --- a/internal/sources/mysql/mysql.go +++ b/internal/sources/mysql/mysql.go @@ -24,7 +24,9 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "go.opentelemetry.io/otel/trace" ) @@ -100,6 +102,60 @@ func (s *Source) MySQLPool() *sql.DB { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.MySQLPool().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + cols, err := results.Columns() + if err != nil { + return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) + } + + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + + colTypes, err := results.ColumnTypes() + if err != nil { + return nil, fmt.Errorf("unable to get column types: %w", err) + } + + var out []any + for results.Next() { + err := results.Scan(values...) + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + row := orderedmap.Row{} + for i, name := range cols { + val := rawValues[i] + if val == nil { + row.Add(name, nil) + continue + } + + convertedValue, err := mysqlcommon.ConvertToType(colTypes[i], val) + if err != nil { + return nil, fmt.Errorf("errors encountered when converting values: %w", err) + } + row.Add(name, convertedValue) + } + out = append(out, row) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + } + + return out, nil +} + func initMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string, queryParams map[string]string) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go index ddfbdb089e..8963544d41 100644 --- a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go +++ b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go @@ -23,7 +23,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" - "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -45,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MSSQLDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -106,47 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - - results, err := source.MSSQLDB().QueryContext(ctx, sql) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - // If Columns() errors, it might be a DDL/DML without an OUTPUT clause. - // We proceed, and results.Err() will catch actual query execution errors. - // 'out' will remain nil if cols is empty or err is not nil here. - - var out []any - if err == nil && len(cols) > 0 { - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - for results.Next() { - scanErr := results.Scan(values...) - if scanErr != nil { - return nil, fmt.Errorf("unable to parse row: %w", scanErr) - } - row := orderedmap.Row{} - for i, name := range cols { - row.Add(name, rawValues[i]) - } - out = append(out, row) - } - } - - // Check for errors from iterating over rows or from the query execution itself. - // results.Close() is handled by defer. - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err) - } - - return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mssql/mssqllisttables/mssqllisttables.go b/internal/tools/mssql/mssqllisttables/mssqllisttables.go index 29fbea4498..633f43dee7 100644 --- a/internal/tools/mssql/mssqllisttables/mssqllisttables.go +++ b/internal/tools/mssql/mssqllisttables/mssqllisttables.go @@ -292,6 +292,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MSSQLDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -354,44 +355,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sql.Named("table_names", paramsMap["table_names"]), sql.Named("output_format", outputFormat), } - - rows, err := source.MSSQLDB().QueryContext(ctx, listTablesStatement, namedArgs...) + resp, err := source.RunSQL(ctx, listTablesStatement, namedArgs) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, err } - defer rows.Close() - - cols, err := rows.Columns() - if err != nil { - return nil, fmt.Errorf("unable to fetch column names: %w", err) + // if there's no results, return empty list instead of null + resSlice, ok := resp.([]any) + if !ok || len(resSlice) == 0 { + return []any{}, nil } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - out := []any{} - for rows.Next() { - err = rows.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - vMap[name] = rawValues[i] - } - out = append(out, vMap) - } - - // Check if error occurred during iteration - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return resp, err } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mssql/mssqlsql/mssqlsql.go b/internal/tools/mssql/mssqlsql/mssqlsql.go index 0e621b7417..1b97a889bf 100644 --- a/internal/tools/mssql/mssqlsql/mssqlsql.go +++ b/internal/tools/mssql/mssqlsql/mssqlsql.go @@ -44,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MSSQLDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -121,47 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para namedArgs = append(namedArgs, value) } } - - rows, err := source.MSSQLDB().QueryContext(ctx, newStatement, namedArgs...) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - cols, err := rows.Columns() - if err != nil { - return nil, fmt.Errorf("unable to fetch column types: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - var out []any - for rows.Next() { - err = rows.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - vMap[name] = rawValues[i] - } - out = append(out, vMap) - } - err = rows.Close() - if err != nil { - return nil, fmt.Errorf("unable to close rows: %w", err) - } - - // Check if error occurred during iteration - if err := rows.Err(); err != nil { - return nil, err - } - - return out, nil + return source.RunSQL(ctx, newStatement, namedArgs) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go index 5198602d70..1f2a5bdee3 100644 --- a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go +++ b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go @@ -22,9 +22,7 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" - "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -46,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -107,58 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - - results, err := source.MySQLPool().QueryContext(ctx, sql) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - row := orderedmap.Row{} - for i, name := range cols { - val := rawValues[i] - if val == nil { - row.Add(name, nil) - continue - } - - convertedValue, err := mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - row.Add(name, convertedValue) - } - out = append(out, row) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go index 3458a6ed83..04ab5c23e0 100644 --- a/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go +++ b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go @@ -24,6 +24,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -45,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -108,30 +110,27 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) query := fmt.Sprintf("EXPLAIN FORMAT=JSON %s", sql) - results, err := source.MySQLPool().QueryContext(ctx, query) + result, err := source.RunSQL(ctx, query, nil) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, err } - defer results.Close() - - var plan string - if results.Next() { - if err := results.Scan(&plan); err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - } else { + // extract and return only the query plan object + resSlice, ok := result.([]any) + if !ok || len(resSlice) == 0 { return nil, fmt.Errorf("no query plan returned") } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + row, ok := resSlice[0].(orderedmap.Row) + if !ok || len(row.Columns) == 0 { + return nil, fmt.Errorf("no query plan returned in row") } - - var out any + plan, ok := row.Columns[0].Value.(string) + if !ok { + return nil, fmt.Errorf("unable to convert plan object to string") + } + var out map[string]any if err := json.Unmarshal([]byte(plan), &out); err != nil { return nil, fmt.Errorf("failed to unmarshal query plan json: %w", err) } - return out, nil } diff --git a/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go b/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go index 323d582d32..6124115c78 100644 --- a/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go +++ b/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go @@ -24,7 +24,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" "github.com/googleapis/genai-toolbox/internal/sources/mysql" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -109,6 +108,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -146,14 +146,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) var statement string sourceKind := rawS.SourceKind() - switch sourceKind { case mysql.SourceKind: statement = listActiveQueriesStatementMySQL case cloudsqlmysql.SourceKind: statement = listActiveQueriesStatementCloudSQLMySQL default: - return nil, fmt.Errorf("unsupported source kind kind: %q", sourceKind) + return nil, fmt.Errorf("unsupported source kind: %s", cfg.Source) } // finish tool setup t := Tool{ @@ -200,57 +199,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, t.statement)) - - results, err := source.MySQLPool().QueryContext(ctx, t.statement, duration, duration, limit) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, t.statement, []any{duration, duration, limit}) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go b/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go index a0bc1b8f66..28cce1bc54 100644 --- a/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go +++ b/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -67,6 +66,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -144,57 +144,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, listTableFragmentationStatement)) - - results, err := source.MySQLPool().QueryContext(ctx, listTableFragmentationStatement, table_schema, table_schema, table_name, table_name, data_free_threshold_bytes, limit) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + sliceParams := []any{table_schema, table_schema, table_name, table_name, data_free_threshold_bytes, limit} + return source.RunSQL(ctx, listTableFragmentationStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqllisttables/mysqllisttables.go b/internal/tools/mysql/mysqllisttables/mysqllisttables.go index 66928b75fa..f8e0c1dced 100644 --- a/internal/tools/mysql/mysqllisttables/mysqllisttables.go +++ b/internal/tools/mysql/mysqllisttables/mysqllisttables.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -197,6 +196,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -259,57 +259,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if outputFormat != "simple" && outputFormat != "detailed" { return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) } - - results, err := source.MySQLPool().QueryContext(ctx, listTablesStatement, tableNames, outputFormat) + resp, err := source.RunSQL(ctx, listTablesStatement, []any{tableNames, outputFormat}) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, err } - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) + // if there's no results, return empty list instead of null + resSlice, ok := resp.([]any) + if !ok || len(resSlice) == 0 { + return []any{}, nil } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - defer results.Close() - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - out := []any{} - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return resp, err } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go b/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go index 522b180acd..e19e14a33d 100644 --- a/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go +++ b/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -68,6 +67,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -135,57 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, listTablesMissingUniqueIndexesStatement)) - - results, err := source.MySQLPool().QueryContext(ctx, listTablesMissingUniqueIndexesStatement, table_schema, table_schema, limit) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listTablesMissingUniqueIndexesStatement, []any{table_schema, table_schema, limit}) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqlsql/mysqlsql.go b/internal/tools/mysql/mysqlsql/mysqlsql.go index edf5f65db1..f89dde648b 100644 --- a/internal/tools/mysql/mysqlsql/mysqlsql.go +++ b/internal/tools/mysql/mysqlsql/mysqlsql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -110,56 +110,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := source.MySQLPool().QueryContext(ctx, newStatement, sliceParams...) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - defer results.Close() - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {