mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-07 22:54:06 -05:00
refactor: move source implementation in Invoke() function to Source (#2237)
Move source-related queries from `Invoke()` function into Source. The following sources are updated in this PR: * spanner * sqlite * tidb * trino * valkey * yugabytedb This is an effort to generalizing tools to work with any Source that implements a specific interface. This will provide a better segregation of the roles for Tools vs Source. Tool's role will be limited to the following: * Resolve any pre-implementation steps or parameters (e.g. template parameters) * Retrieving Source * Calling the source's implementation
This commit is contained in:
@@ -16,13 +16,16 @@ package spanner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"cloud.google.com/go/spanner"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
const SourceKind string = "spanner"
|
||||
@@ -93,6 +96,79 @@ func (s *Source) DatabaseDialect() string {
|
||||
return s.Dialect.String()
|
||||
}
|
||||
|
||||
// processRows iterates over the spanner.RowIterator and converts each row to a map[string]any.
|
||||
func processRows(iter *spanner.RowIterator) ([]any, error) {
|
||||
var out []any
|
||||
defer iter.Stop()
|
||||
|
||||
for {
|
||||
row, err := iter.Next()
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
|
||||
rowMap := orderedmap.Row{}
|
||||
cols := row.ColumnNames()
|
||||
for i, c := range cols {
|
||||
if c == "object_details" { // for list graphs or list tables
|
||||
val := row.ColumnValue(i)
|
||||
if val == nil { // ColumnValue returns the Cloud Spanner Value of column i, or nil for invalid column.
|
||||
rowMap.Add(c, nil)
|
||||
} else {
|
||||
jsonString, ok := val.AsInterface().(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("column 'object_details' is not a string, but %T", val.AsInterface())
|
||||
}
|
||||
var details map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonString), &details); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal JSON: %w", err)
|
||||
}
|
||||
rowMap.Add(c, details)
|
||||
}
|
||||
} else {
|
||||
rowMap.Add(c, row.ColumnValue(i))
|
||||
}
|
||||
}
|
||||
out = append(out, rowMap)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Source) RunSQL(ctx context.Context, readOnly bool, statement string, params map[string]any) (any, error) {
|
||||
var results []any
|
||||
var err error
|
||||
var opErr error
|
||||
stmt := spanner.Statement{
|
||||
SQL: statement,
|
||||
}
|
||||
if params != nil {
|
||||
stmt.Params = params
|
||||
}
|
||||
|
||||
if readOnly {
|
||||
iter := s.SpannerClient().Single().Query(ctx, stmt)
|
||||
results, opErr = processRows(iter)
|
||||
} else {
|
||||
_, opErr = s.SpannerClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
|
||||
iter := txn.Query(ctx, stmt)
|
||||
results, err = processRows(iter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if opErr != nil {
|
||||
return nil, fmt.Errorf("unable to execute client: %w", opErr)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func initSpannerClient(ctx context.Context, tracer trace.Tracer, name, project, instance, dbname string) (*spanner.Client, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
|
||||
@@ -17,10 +17,12 @@ package sqlite
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
_ "modernc.org/sqlite" // Pure Go SQLite driver
|
||||
)
|
||||
@@ -91,6 +93,66 @@ func (s *Source) SQLiteDB() *sql.DB {
|
||||
return s.Db
|
||||
}
|
||||
|
||||
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
|
||||
// Execute the SQL query with parameters
|
||||
rows, err := s.SQLiteDB().QueryContext(ctx, statement, params...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Get column names
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get column names: %w", err)
|
||||
}
|
||||
|
||||
// The sqlite driver does not support ColumnTypes, so we can't get the
|
||||
// underlying database type of the columns. We'll have to rely on the
|
||||
// generic `any` type and then handle the JSON data separately.
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
|
||||
// Prepare the result slice
|
||||
var out []any
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
return nil, fmt.Errorf("unable to scan row: %w", err)
|
||||
}
|
||||
|
||||
// Create a map for this row
|
||||
row := orderedmap.Row{}
|
||||
for i, name := range cols {
|
||||
val := rawValues[i]
|
||||
// Handle nil values
|
||||
if val == nil {
|
||||
row.Add(name, nil)
|
||||
continue
|
||||
}
|
||||
// Handle JSON data
|
||||
if jsonString, ok := val.(string); ok {
|
||||
var unmarshaledData any
|
||||
if json.Unmarshal([]byte(jsonString), &unmarshaledData) == nil {
|
||||
row.Add(name, unmarshaledData)
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Store the value in the map
|
||||
row.Add(name, val)
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating rows: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func initSQLiteConnection(ctx context.Context, tracer trace.Tracer, name, dbPath string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
|
||||
@@ -17,6 +17,7 @@ package tidb
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
@@ -104,6 +105,79 @@ func (s *Source) TiDBPool() *sql.DB {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
|
||||
results, err := s.TiDBPool().QueryContext(ctx, statement, params...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
cols, err := results.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
||||
}
|
||||
|
||||
// create an array of values for each column, which can be re-used to scan each row
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
colTypes, err := results.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
err := results.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
val := rawValues[i]
|
||||
if val == nil {
|
||||
vMap[name] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
// mysql driver return []uint8 type for "TEXT", "VARCHAR", and "NVARCHAR"
|
||||
// we'll need to cast it back to string
|
||||
switch colTypes[i].DatabaseTypeName() {
|
||||
case "JSON":
|
||||
// unmarshal JSON data before storing to prevent double
|
||||
// marshaling
|
||||
byteVal, ok := val.([]byte)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected []byte for JSON column, but got %T", val)
|
||||
}
|
||||
var unmarshaledData any
|
||||
if err := json.Unmarshal(byteVal, &unmarshaledData); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal json data %s", val)
|
||||
}
|
||||
vMap[name] = unmarshaledData
|
||||
case "TEXT", "VARCHAR", "NVARCHAR":
|
||||
byteVal, ok := val.([]byte)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected []byte for text-like column, but got %T", val)
|
||||
}
|
||||
vMap[name] = string(byteVal)
|
||||
default:
|
||||
vMap[name] = val
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func IsTiDBCloudHost(host string) bool {
|
||||
pattern := `gateway\d{2}\.(.+)\.(prod|dev|staging)\.(.+)\.tidbcloud\.com`
|
||||
match, err := regexp.MatchString(pattern, host)
|
||||
|
||||
@@ -102,6 +102,56 @@ func (s *Source) TrinoDB() *sql.DB {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
|
||||
results, err := s.TrinoDB().QueryContext(ctx, statement, params...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
cols, err := results.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve column names: %w", err)
|
||||
}
|
||||
|
||||
// create an array of values for each column, which can be re-used to scan each row
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
err := results.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
val := rawValues[i]
|
||||
if val == nil {
|
||||
vMap[name] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert byte arrays to strings for text fields
|
||||
if b, ok := val.([]byte); ok {
|
||||
vMap[name] = string(b)
|
||||
} else {
|
||||
vMap[name] = val
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func initTrinoConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, password, catalog, schema, queryTimeout, accessToken string, kerberosEnabled, sslEnabled bool) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
|
||||
@@ -125,3 +125,37 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
func (s *Source) ValkeyClient() valkey.Client {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func (s *Source) RunCommand(ctx context.Context, cmds [][]string) (any, error) {
|
||||
// Build commands
|
||||
builtCmds := make(valkey.Commands, len(cmds))
|
||||
|
||||
for i, cmd := range cmds {
|
||||
builtCmds[i] = s.ValkeyClient().B().Arbitrary(cmd...).Build()
|
||||
}
|
||||
|
||||
if len(builtCmds) == 0 {
|
||||
return nil, fmt.Errorf("no valid commands were built to execute")
|
||||
}
|
||||
|
||||
// Execute commands
|
||||
responses := s.ValkeyClient().DoMulti(ctx, builtCmds...)
|
||||
|
||||
// Parse responses
|
||||
out := make([]any, len(cmds))
|
||||
for i, resp := range responses {
|
||||
if err := resp.Error(); err != nil {
|
||||
// Store error message in the output for this command
|
||||
out[i] = fmt.Sprintf("error from executing command at index %d: %s", i, err)
|
||||
continue
|
||||
}
|
||||
val, err := resp.ToAny()
|
||||
if err != nil {
|
||||
out[i] = fmt.Sprintf("error parsing response: %s", err)
|
||||
continue
|
||||
}
|
||||
out[i] = val
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
@@ -99,6 +99,35 @@ func (s *Source) YugabyteDBPool() *pgxpool.Pool {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
|
||||
results, err := s.YugabyteDBPool().Query(ctx, statement, params...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
fields := results.FieldDescriptions()
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
v, err := results.Values()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, f := range fields {
|
||||
vMap[f.Name] = v[i]
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
// this will catch actual query execution errors
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func initYugabyteDBConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, loadBalance, topologyKeys, refreshInterval, explicitFallback, failedHostTTL string) (*pgxpool.Pool, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
|
||||
@@ -23,9 +23,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
const kind string = "spanner-execute-sql"
|
||||
@@ -47,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
SpannerClient() *spanner.Client
|
||||
DatabaseDialect() string
|
||||
RunSQL(context.Context, bool, string, map[string]any) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -91,30 +90,6 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// processRows iterates over the spanner.RowIterator and converts each row to a map[string]any.
|
||||
func processRows(iter *spanner.RowIterator) ([]any, error) {
|
||||
var out []any
|
||||
defer iter.Stop()
|
||||
|
||||
for {
|
||||
row, err := iter.Next()
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
|
||||
rowMap := orderedmap.Row{}
|
||||
cols := row.ColumnNames()
|
||||
for i, c := range cols {
|
||||
rowMap.Add(c, row.ColumnValue(i))
|
||||
}
|
||||
out = append(out, rowMap)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
@@ -133,31 +108,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
|
||||
|
||||
var results []any
|
||||
var opErr error
|
||||
stmt := spanner.Statement{SQL: sql}
|
||||
|
||||
if t.ReadOnly {
|
||||
iter := source.SpannerClient().Single().Query(ctx, stmt)
|
||||
results, opErr = processRows(iter)
|
||||
} else {
|
||||
_, opErr = source.SpannerClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
|
||||
var err error
|
||||
iter := txn.Query(ctx, stmt)
|
||||
results, err = processRows(iter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if opErr != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", opErr)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return source.RunSQL(ctx, t.ReadOnly, sql, nil)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,7 +16,6 @@ package spannerlistgraphs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -25,7 +24,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
const kind string = "spanner-list-graphs"
|
||||
@@ -47,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
SpannerClient() *spanner.Client
|
||||
DatabaseDialect() string
|
||||
RunSQL(context.Context, bool, string, map[string]any) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -105,39 +104,6 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// processRows iterates over the spanner.RowIterator and converts each row to a map[string]any.
|
||||
func processRows(iter *spanner.RowIterator) ([]any, error) {
|
||||
var out []any
|
||||
defer iter.Stop()
|
||||
|
||||
for {
|
||||
row, err := iter.Next()
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
|
||||
vMap := make(map[string]any)
|
||||
cols := row.ColumnNames()
|
||||
for i, c := range cols {
|
||||
if c == "object_details" {
|
||||
jsonString := row.ColumnValue(i).AsInterface().(string)
|
||||
var details map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonString), &details); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal JSON: %w", err)
|
||||
}
|
||||
vMap[c] = details
|
||||
} else {
|
||||
vMap[c] = row.ColumnValue(i)
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
@@ -161,20 +127,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
"graph_names": graphNames,
|
||||
"output_format": outputFormat,
|
||||
}
|
||||
|
||||
stmt := spanner.Statement{
|
||||
SQL: googleSQLStatement,
|
||||
Params: stmtParams,
|
||||
}
|
||||
|
||||
// Execute the query (read-only)
|
||||
iter := source.SpannerClient().Single().Query(ctx, stmt)
|
||||
results, err := processRows(iter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return source.RunSQL(ctx, true, googleSQLStatement, stmtParams)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,7 +16,6 @@ package spannerlisttables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -25,7 +24,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
const kind string = "spanner-list-tables"
|
||||
@@ -47,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
SpannerClient() *spanner.Client
|
||||
DatabaseDialect() string
|
||||
RunSQL(context.Context, bool, string, map[string]any) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -105,41 +104,8 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// processRows iterates over the spanner.RowIterator and converts each row to a map[string]any.
|
||||
func processRows(iter *spanner.RowIterator) ([]any, error) {
|
||||
out := []any{}
|
||||
defer iter.Stop()
|
||||
|
||||
for {
|
||||
row, err := iter.Next()
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
|
||||
vMap := make(map[string]any)
|
||||
cols := row.ColumnNames()
|
||||
for i, c := range cols {
|
||||
if c == "object_details" {
|
||||
jsonString := row.ColumnValue(i).AsInterface().(string)
|
||||
var details map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonString), &details); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal JSON: %w", err)
|
||||
}
|
||||
vMap[c] = details
|
||||
} else {
|
||||
vMap[c] = row.ColumnValue(i)
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) getStatement(source compatibleSource) string {
|
||||
switch strings.ToLower(source.DatabaseDialect()) {
|
||||
func getStatement(dialect string) string {
|
||||
switch strings.ToLower(dialect) {
|
||||
case "postgresql":
|
||||
return postgresqlStatement
|
||||
case "googlesql":
|
||||
@@ -159,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
// Get the appropriate SQL statement based on dialect
|
||||
statement := t.getStatement(source)
|
||||
statement := getStatement(source.DatabaseDialect())
|
||||
|
||||
// Prepare parameters based on dialect
|
||||
var stmtParams map[string]interface{}
|
||||
@@ -177,7 +143,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
"p1": tableNames,
|
||||
"p2": outputFormat,
|
||||
}
|
||||
|
||||
case "googlesql":
|
||||
// GoogleSQL uses named parameters (@table_names, @output_format)
|
||||
stmtParams = map[string]interface{}{
|
||||
@@ -188,19 +153,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unsupported dialect: %s", source.DatabaseDialect())
|
||||
}
|
||||
|
||||
stmt := spanner.Statement{
|
||||
SQL: statement,
|
||||
Params: stmtParams,
|
||||
}
|
||||
|
||||
// Execute the query (read-only)
|
||||
iter := source.SpannerClient().Single().Query(ctx, stmt)
|
||||
results, err := processRows(iter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return source.RunSQL(ctx, true, statement, stmtParams)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
const kind string = "spanner-sql"
|
||||
@@ -46,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
SpannerClient() *spanner.Client
|
||||
DatabaseDialect() string
|
||||
RunSQL(context.Context, bool, string, map[string]any) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -106,30 +106,6 @@ func getMapParams(params parameters.ParamValues, dialect string) (map[string]int
|
||||
}
|
||||
}
|
||||
|
||||
// processRows iterates over the spanner.RowIterator and converts each row to a map[string]any.
|
||||
func processRows(iter *spanner.RowIterator) ([]any, error) {
|
||||
var out []any
|
||||
defer iter.Stop()
|
||||
|
||||
for {
|
||||
row, err := iter.Next()
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
|
||||
vMap := make(map[string]any)
|
||||
cols := row.ColumnNames()
|
||||
for i, c := range cols {
|
||||
vMap[c] = row.ColumnValue(i)
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
@@ -174,33 +150,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fail to get map params: %w", err)
|
||||
}
|
||||
|
||||
var results []any
|
||||
var opErr error
|
||||
stmt := spanner.Statement{
|
||||
SQL: newStatement,
|
||||
Params: mapParams,
|
||||
}
|
||||
|
||||
if t.ReadOnly {
|
||||
iter := source.SpannerClient().Single().Query(ctx, stmt)
|
||||
results, opErr = processRows(iter)
|
||||
} else {
|
||||
_, opErr = source.SpannerClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
|
||||
iter := txn.Query(ctx, stmt)
|
||||
results, err = processRows(iter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if opErr != nil {
|
||||
return nil, fmt.Errorf("unable to execute client: %w", opErr)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return source.RunSQL(ctx, t.ReadOnly, newStatement, mapParams)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,14 +17,12 @@ package sqliteexecutesql
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -46,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
SQLiteDB() *sql.DB
|
||||
RunSQL(context.Context, string, []any) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -108,65 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
|
||||
|
||||
results, err := source.SQLiteDB().QueryContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
cols, err := results.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
||||
}
|
||||
|
||||
// The sqlite driver does not support ColumnTypes, so we can't get the
|
||||
// underlying database type of the columns. We'll have to rely on the
|
||||
// generic `any` type and then handle the JSON data separately.
|
||||
|
||||
// create an array of values for each column, which can be re-used to scan each row
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
err := results.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
row := orderedmap.Row{}
|
||||
for i, name := range cols {
|
||||
val := rawValues[i]
|
||||
if val == nil {
|
||||
row.Add(name, nil)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle JSON data
|
||||
if jsonString, ok := val.(string); ok {
|
||||
var unmarshaledData any
|
||||
if json.Unmarshal([]byte(jsonString), &unmarshaledData) == nil {
|
||||
row.Add(name, unmarshaledData)
|
||||
continue
|
||||
}
|
||||
}
|
||||
row.Add(name, val)
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||
}
|
||||
|
||||
if len(out) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return source.RunSQL(ctx, sql, nil)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,7 +17,6 @@ package sqlitesql
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
@@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
SQLiteDB() *sql.DB
|
||||
RunSQL(context.Context, string, []any) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -108,64 +108,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
|
||||
// Execute the SQL query with parameters
|
||||
rows, err := source.SQLiteDB().QueryContext(ctx, newStatement, newParams.AsSlice()...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Get column names
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get column names: %w", err)
|
||||
}
|
||||
|
||||
// The sqlite driver does not support ColumnTypes, so we can't get the
|
||||
// underlying database type of the columns. We'll have to rely on the
|
||||
// generic `any` type and then handle the JSON data separately.
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
|
||||
// Prepare the result slice
|
||||
var out []any
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
return nil, fmt.Errorf("unable to scan row: %w", err)
|
||||
}
|
||||
|
||||
// Create a map for this row
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
val := rawValues[i]
|
||||
// Handle nil values
|
||||
if val == nil {
|
||||
vMap[name] = nil
|
||||
continue
|
||||
}
|
||||
// Handle JSON data
|
||||
if jsonString, ok := val.(string); ok {
|
||||
var unmarshaledData any
|
||||
if json.Unmarshal([]byte(jsonString), &unmarshaledData) == nil {
|
||||
vMap[name] = unmarshaledData
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Store the value in the map
|
||||
vMap[name] = val
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating rows: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return source.RunSQL(ctx, newStatement, newParams.AsSlice())
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -44,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
TiDBPool() *sql.DB
|
||||
RunSQL(context.Context, string, []any) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -105,61 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
|
||||
|
||||
results, err := source.TiDBPool().QueryContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
cols, err := results.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
||||
}
|
||||
|
||||
// create an array of values for each column, which can be re-used to scan each row
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
|
||||
colTypes, err := results.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
err := results.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
val := rawValues[i]
|
||||
if val == nil {
|
||||
vMap[name] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
// mysql driver return []uint8 type for "TEXT", "VARCHAR", and "NVARCHAR"
|
||||
// we'll need to cast it back to string
|
||||
switch colTypes[i].DatabaseTypeName() {
|
||||
case "TEXT", "VARCHAR", "NVARCHAR":
|
||||
vMap[name] = string(val.([]byte))
|
||||
default:
|
||||
vMap[name] = val
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return source.RunSQL(ctx, sql, nil)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,7 +17,6 @@ package tidbsql
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
@@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
TiDBPool() *sql.DB
|
||||
RunSQL(context.Context, string, []any) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -110,68 +110,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
sliceParams := newParams.AsSlice()
|
||||
results, err := source.TiDBPool().QueryContext(ctx, newStatement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
cols, err := results.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
||||
}
|
||||
|
||||
// create an array of values for each column, which can be re-used to scan each row
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
colTypes, err := results.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
err := results.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
val := rawValues[i]
|
||||
if val == nil {
|
||||
vMap[name] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
// mysql driver return []uint8 type for "TEXT", "VARCHAR", and "NVARCHAR"
|
||||
// we'll need to cast it back to string
|
||||
switch colTypes[i].DatabaseTypeName() {
|
||||
case "JSON":
|
||||
// unmarshal JSON data before storing to prevent double marshaling
|
||||
var unmarshaledData any
|
||||
err := json.Unmarshal(val.([]byte), &unmarshaledData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal json data %s", val)
|
||||
}
|
||||
vMap[name] = unmarshaledData
|
||||
case "TEXT", "VARCHAR", "NVARCHAR":
|
||||
vMap[name] = string(val.([]byte))
|
||||
default:
|
||||
vMap[name] = val
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return source.RunSQL(ctx, newStatement, sliceParams)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -43,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
TrinoDB() *sql.DB
|
||||
RunSQL(context.Context, string, []any) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -97,54 +98,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast sql parameter: %v", sliceParams[0])
|
||||
}
|
||||
|
||||
results, err := source.TrinoDB().QueryContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
cols, err := results.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve column names: %w", err)
|
||||
}
|
||||
|
||||
// create an array of values for each column, which can be re-used to scan each row
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
err := results.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
val := rawValues[i]
|
||||
if val == nil {
|
||||
vMap[name] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert byte arrays to strings for text fields
|
||||
if b, ok := val.([]byte); ok {
|
||||
vMap[name] = string(b)
|
||||
} else {
|
||||
vMap[name] = val
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return source.RunSQL(ctx, sql, nil)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -43,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
TrinoDB() *sql.DB
|
||||
RunSQL(context.Context, string, []any) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -107,53 +108,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
results, err := source.TrinoDB().QueryContext(ctx, newStatement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
cols, err := results.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve column names: %w", err)
|
||||
}
|
||||
|
||||
// create an array of values for each column, which can be re-used to scan each row
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
err := results.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
val := rawValues[i]
|
||||
if val == nil {
|
||||
vMap[name] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert byte arrays to strings for text fields
|
||||
if b, ok := val.([]byte); ok {
|
||||
vMap[name] = string(b)
|
||||
} else {
|
||||
vMap[name] = val
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return source.RunSQL(ctx, newStatement, sliceParams)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -42,6 +42,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
ValkeyClient() valkey.Client
|
||||
RunCommand(context.Context, [][]string) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -93,38 +94,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error replacing commands' parameters: %s", err)
|
||||
}
|
||||
|
||||
// Build commands
|
||||
builtCmds := make(valkey.Commands, len(commands))
|
||||
|
||||
for i, cmd := range commands {
|
||||
builtCmds[i] = source.ValkeyClient().B().Arbitrary(cmd...).Build()
|
||||
}
|
||||
|
||||
if len(builtCmds) == 0 {
|
||||
return nil, fmt.Errorf("no valid commands were built to execute")
|
||||
}
|
||||
|
||||
// Execute commands
|
||||
responses := source.ValkeyClient().DoMulti(ctx, builtCmds...)
|
||||
|
||||
// Parse responses
|
||||
out := make([]any, len(t.Commands))
|
||||
for i, resp := range responses {
|
||||
if err := resp.Error(); err != nil {
|
||||
// Add error from each command to `errSum`
|
||||
out[i] = fmt.Sprintf("error from executing command at index %d: %s", i, err)
|
||||
continue
|
||||
}
|
||||
val, err := resp.ToAny()
|
||||
if err != nil {
|
||||
out[i] = fmt.Sprintf("error parsing response: %s", err)
|
||||
continue
|
||||
}
|
||||
out[i] = val
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return source.RunCommand(ctx, commands)
|
||||
}
|
||||
|
||||
// replaceCommandsParams is a helper function to replace parameters in the commands
|
||||
|
||||
@@ -43,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
YugabyteDBPool() *pgxpool.Pool
|
||||
RunSQL(context.Context, string, []any) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -108,32 +109,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
results, err := source.YugabyteDBPool().Query(ctx, newStatement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
fields := results.FieldDescriptions()
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
v, err := results.Values()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, f := range fields {
|
||||
vMap[f.Name] = v[i]
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
// this will catch actual query execution errors
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return source.RunSQL(ctx, newStatement, sliceParams)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
Reference in New Issue
Block a user