From df2f6a9f0bcc86b1cb11401f217062f3f6ef9d4b Mon Sep 17 00:00:00 2001 From: Yuan Teoh <45984206+Yuan325@users.noreply.github.com> Date: Tue, 30 Dec 2025 14:45:48 -0800 Subject: [PATCH] refactor: move source implementation in Invoke() function to Source (#2234) Move source-related queries from `Invoke()` function into Source. The following sources are updated in this PR: * couchbase * dgraph * elasticsearch * firebird This is an effort to generalizing tools to work with any Source that implements a specific interface. This will provide a better segregation of the roles for Tools vs Source. Tool's role will be limited to the following: * Resolve any pre-implementation steps or parameters (e.g. template parameters) * Retrieving Source * Calling the source's implementation --- internal/sources/couchbase/couchbase.go | 23 +++ internal/sources/dgraph/dgraph.go | 27 ++- .../sources/elasticsearch/elasticsearch.go | 79 +++++++++ .../elasticsearch/elasticsearch_test.go | 153 +++++++++++++++++ internal/sources/firebird/firebird.go | 47 ++++++ internal/tools/couchbase/couchbase.go | 22 +-- internal/tools/dgraph/dgraph.go | 24 +-- .../elasticsearchesql/elasticsearchesql.go | 92 +---------- .../elasticsearchesql_test.go | 154 ------------------ .../firebirdexecutesql/firebirdexecutesql.go | 45 +---- .../tools/firebird/firebirdsql/firebirdsql.go | 47 +----- 11 files changed, 343 insertions(+), 370 deletions(-) diff --git a/internal/sources/couchbase/couchbase.go b/internal/sources/couchbase/couchbase.go index 422d9ab001..c273a47ec0 100644 --- a/internal/sources/couchbase/couchbase.go +++ b/internal/sources/couchbase/couchbase.go @@ -17,6 +17,7 @@ package couchbase import ( "context" "crypto/tls" + "encoding/json" "fmt" "os" @@ -24,6 +25,7 @@ import ( tlsutil "github.com/couchbase/tools-common/http/tls" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.opentelemetry.io/otel/trace" ) @@ -110,6 +112,27 @@ func (s *Source) CouchbaseQueryScanConsistency() uint { return s.QueryScanConsistency } +func (s *Source) RunSQL(statement string, params parameters.ParamValues) (any, error) { + results, err := s.CouchbaseScope().Query(statement, &gocb.QueryOptions{ + ScanConsistency: gocb.QueryScanConsistency(s.CouchbaseQueryScanConsistency()), + NamedParameters: params.AsMap(), + }) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + var out []any + for results.Next() { + var result json.RawMessage + err := results.Row(&result) + if err != nil { + return nil, fmt.Errorf("error processing row: %w", err) + } + out = append(out, result) + } + return out, nil +} + func (r Config) createCouchbaseOptions() (gocb.ClusterOptions, error) { cbOpts := gocb.ClusterOptions{} diff --git a/internal/sources/dgraph/dgraph.go b/internal/sources/dgraph/dgraph.go index 24f8f8b20e..317779db38 100644 --- a/internal/sources/dgraph/dgraph.go +++ b/internal/sources/dgraph/dgraph.go @@ -26,6 +26,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.opentelemetry.io/otel/trace" ) @@ -114,6 +115,28 @@ func (s *Source) DgraphClient() *DgraphClient { return s.Client } +func (s *Source) RunSQL(statement string, params parameters.ParamValues, isQuery bool, timeout string) (any, error) { + paramsMap := params.AsMapWithDollarPrefix() + resp, err := s.DgraphClient().ExecuteQuery(statement, paramsMap, isQuery, timeout) + if err != nil { + return nil, err + } + + if err := checkError(resp); err != nil { + return nil, err + } + + var result struct { + Data map[string]interface{} `json:"data"` + } + + if err := json.Unmarshal(resp, &result); err != nil { + return nil, fmt.Errorf("error parsing JSON: %v", err) + } + + return result.Data, nil +} + func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (*DgraphClient, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, r.Name) @@ -285,7 +308,7 @@ func (hc *DgraphClient) doLogin(creds map[string]interface{}) error { return err } - if err := CheckError(resp); err != nil { + if err := checkError(resp); err != nil { return err } @@ -370,7 +393,7 @@ func getUrl(baseUrl, resource string, params url.Values) (string, error) { return u.String(), nil } -func CheckError(resp []byte) error { +func checkError(resp []byte) error { var errResp struct { Errors []struct { Message string `json:"message"` diff --git a/internal/sources/elasticsearch/elasticsearch.go b/internal/sources/elasticsearch/elasticsearch.go index 2d7b788407..b5ec915c18 100644 --- a/internal/sources/elasticsearch/elasticsearch.go +++ b/internal/sources/elasticsearch/elasticsearch.go @@ -15,7 +15,9 @@ package elasticsearch import ( + "bytes" "context" + "encoding/json" "fmt" "net/http" @@ -149,3 +151,80 @@ func (s *Source) ToConfig() sources.SourceConfig { func (s *Source) ElasticsearchClient() EsClient { return s.Client } + +type EsqlColumn struct { + Name string `json:"name"` + Type string `json:"type"` +} + +type EsqlResult struct { + Columns []EsqlColumn `json:"columns"` + Values [][]any `json:"values"` +} + +func (s *Source) RunSQL(ctx context.Context, format, query string, params []map[string]any) (any, error) { + bodyStruct := struct { + Query string `json:"query"` + Params []map[string]any `json:"params,omitempty"` + }{ + Query: query, + Params: params, + } + body, err := json.Marshal(bodyStruct) + if err != nil { + return nil, fmt.Errorf("failed to marshal query body: %w", err) + } + + res, err := esapi.EsqlQueryRequest{ + Body: bytes.NewReader(body), + Format: format, + FilterPath: []string{"columns", "values"}, + Instrument: s.ElasticsearchClient().InstrumentationEnabled(), + }.Do(ctx, s.ElasticsearchClient()) + + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.IsError() { + // Try to extract error message from response + var esErr json.RawMessage + err = util.DecodeJSON(res.Body, &esErr) + if err != nil { + return nil, fmt.Errorf("elasticsearch error: status %s", res.Status()) + } + return esErr, nil + } + + var result EsqlResult + err = util.DecodeJSON(res.Body, &result) + if err != nil { + return nil, fmt.Errorf("failed to decode response body: %w", err) + } + + output := EsqlToMap(result) + + return output, nil +} + +// EsqlToMap converts the esqlResult to a slice of maps. +func EsqlToMap(result EsqlResult) []map[string]any { + output := make([]map[string]any, 0, len(result.Values)) + for _, value := range result.Values { + row := make(map[string]any) + if value == nil { + output = append(output, row) + continue + } + for i, col := range result.Columns { + if i < len(value) { + row[col.Name] = value[i] + } else { + row[col.Name] = nil + } + } + output = append(output, row) + } + return output +} diff --git a/internal/sources/elasticsearch/elasticsearch_test.go b/internal/sources/elasticsearch/elasticsearch_test.go index 6ea9d33dce..95d941edc4 100644 --- a/internal/sources/elasticsearch/elasticsearch_test.go +++ b/internal/sources/elasticsearch/elasticsearch_test.go @@ -15,6 +15,7 @@ package elasticsearch_test import ( + "reflect" "testing" yaml "github.com/goccy/go-yaml" @@ -64,3 +65,155 @@ func TestParseFromYamlElasticsearch(t *testing.T) { }) } } + +func TestTool_esqlToMap(t1 *testing.T) { + tests := []struct { + name string + result elasticsearch.EsqlResult + want []map[string]any + }{ + { + name: "simple case with two rows", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "first_name", Type: "text"}, + {Name: "last_name", Type: "text"}, + }, + Values: [][]any{ + {"John", "Doe"}, + {"Jane", "Smith"}, + }, + }, + want: []map[string]any{ + {"first_name": "John", "last_name": "Doe"}, + {"first_name": "Jane", "last_name": "Smith"}, + }, + }, + { + name: "different data types", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "id", Type: "integer"}, + {Name: "active", Type: "boolean"}, + {Name: "score", Type: "float"}, + }, + Values: [][]any{ + {1, true, 95.5}, + {2, false, 88.0}, + }, + }, + want: []map[string]any{ + {"id": 1, "active": true, "score": 95.5}, + {"id": 2, "active": false, "score": 88.0}, + }, + }, + { + name: "no rows", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + }, + Values: [][]any{}, + }, + want: []map[string]any{}, + }, + { + name: "null values", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + }, + Values: [][]any{ + {1, nil}, + {2, "Alice"}, + }, + }, + want: []map[string]any{ + {"id": 1, "name": nil}, + {"id": 2, "name": "Alice"}, + }, + }, + { + name: "missing values in a row", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + {Name: "age", Type: "integer"}, + }, + Values: [][]any{ + {1, "Bob"}, + {2, "Charlie", 30}, + }, + }, + want: []map[string]any{ + {"id": 1, "name": "Bob", "age": nil}, + {"id": 2, "name": "Charlie", "age": 30}, + }, + }, + { + name: "all null row", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + }, + Values: [][]any{ + nil, + }, + }, + want: []map[string]any{ + {}, + }, + }, + { + name: "empty columns", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{}, + Values: [][]any{ + {}, + {}, + }, + }, + want: []map[string]any{ + {}, + {}, + }, + }, + { + name: "more values than columns", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "id", Type: "integer"}, + }, + Values: [][]any{ + {1, "extra"}, + }, + }, + want: []map[string]any{ + {"id": 1}, + }, + }, + { + name: "no columns but with values", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{}, + Values: [][]any{ + {1, "data"}, + }, + }, + want: []map[string]any{ + {}, + }, + }, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + if got := elasticsearch.EsqlToMap(tt.result); !reflect.DeepEqual(got, tt.want) { + t1.Errorf("esqlToMap() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/sources/firebird/firebird.go b/internal/sources/firebird/firebird.go index 43775be70c..4be3d20cac 100644 --- a/internal/sources/firebird/firebird.go +++ b/internal/sources/firebird/firebird.go @@ -96,6 +96,53 @@ func (s *Source) FirebirdDB() *sql.DB { return s.Db } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + rows, err := s.FirebirdDB().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + return nil, fmt.Errorf("unable to get columns: %w", err) + } + + values := make([]any, len(cols)) + scanArgs := make([]any, len(values)) + for i := range values { + scanArgs[i] = &values[i] + } + + var out []any + for rows.Next() { + + err = rows.Scan(scanArgs...) + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + + vMap := make(map[string]any) + for i, col := range cols { + if b, ok := values[i].([]byte); ok { + vMap[col] = string(b) + } else { + vMap[col] = values[i] + } + } + out = append(out, vMap) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating rows: %w", err) + } + + // In most cases, DML/DDL statements like INSERT, UPDATE, CREATE, etc. might return no rows + // However, it is also possible that this was a query that was expected to return rows + // but returned none, a case that we cannot distinguish here. + return out, nil +} + func initFirebirdConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*sql.DB, error) { _, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) defer span.End() diff --git a/internal/tools/couchbase/couchbase.go b/internal/tools/couchbase/couchbase.go index 481c9f6b22..a04a3b0aee 100644 --- a/internal/tools/couchbase/couchbase.go +++ b/internal/tools/couchbase/couchbase.go @@ -16,7 +16,6 @@ package couchbase import ( "context" - "encoding/json" "fmt" "github.com/couchbase/gocb/v2" @@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { CouchbaseScope() *gocb.Scope - CouchbaseQueryScanConsistency() uint + RunSQL(string, parameters.ParamValues) (any, error) } type Config struct { @@ -112,24 +111,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("unable to extract standard params %w", err) } - results, err := source.CouchbaseScope().Query(newStatement, &gocb.QueryOptions{ - ScanConsistency: gocb.QueryScanConsistency(source.CouchbaseQueryScanConsistency()), - NamedParameters: newParams.AsMap(), - }) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - var out []any - for results.Next() { - var result json.RawMessage - err := results.Row(&result) - if err != nil { - return nil, fmt.Errorf("error processing row: %w", err) - } - out = append(out, result) - } - return out, nil + return source.RunSQL(newStatement, newParams) } func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/dgraph/dgraph.go b/internal/tools/dgraph/dgraph.go index beef9f86a5..2ed68fe209 100644 --- a/internal/tools/dgraph/dgraph.go +++ b/internal/tools/dgraph/dgraph.go @@ -16,7 +16,6 @@ package dgraph import ( "context" - "encoding/json" "fmt" yaml "github.com/goccy/go-yaml" @@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { DgraphClient() *dgraph.DgraphClient + RunSQL(string, parameters.ParamValues, bool, string) (any, error) } type Config struct { @@ -95,27 +95,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, err } - - paramsMap := params.AsMapWithDollarPrefix() - - resp, err := source.DgraphClient().ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout) - if err != nil { - return nil, err - } - - if err := dgraph.CheckError(resp); err != nil { - return nil, err - } - - var result struct { - Data map[string]interface{} `json:"data"` - } - - if err := json.Unmarshal(resp, &result); err != nil { - return nil, fmt.Errorf("error parsing JSON: %v", err) - } - - return result.Data, nil + return source.RunSQL(t.Statement, params, t.IsQuery, t.Timeout) } func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go index d7cbb35722..57f78a4403 100644 --- a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go +++ b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go @@ -15,14 +15,10 @@ package elasticsearchesql import ( - "bytes" "context" - "encoding/json" "fmt" "time" - "github.com/elastic/go-elasticsearch/v9/esapi" - "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/goccy/go-yaml" @@ -41,6 +37,7 @@ func init() { type compatibleSource interface { ElasticsearchClient() es.EsClient + RunSQL(ctx context.Context, format, query string, params []map[string]any) (any, error) } type Config struct { @@ -91,16 +88,6 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -type esqlColumn struct { - Name string `json:"name"` - Type string `json:"type"` -} - -type esqlResult struct { - Columns []esqlColumn `json:"columns"` - Values [][]any `json:"values"` -} - func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { @@ -116,20 +103,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para defer cancel() } - bodyStruct := struct { - Query string `json:"query"` - Params []map[string]any `json:"params,omitempty"` - }{ - Query: t.Query, - Params: make([]map[string]any, 0, len(params)), - } - + query := t.Query + sqlParams := make([]map[string]any, 0, len(params)) paramMap := params.AsMap() - // If a query is provided in the params and not already set in the tool, use it. - if query, ok := paramMap["query"]; ok { - if str, ok := query.(string); ok && bodyStruct.Query == "" { - bodyStruct.Query = str + if queryVal, ok := paramMap["query"]; ok { + if str, ok := queryVal.(string); ok && t.Query == "" { + query = str } // Drop the query param if not a string or if the tool already has a query. @@ -140,65 +120,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if param.GetType() == "array" { return nil, fmt.Errorf("array parameters are not supported yet") } - bodyStruct.Params = append(bodyStruct.Params, map[string]any{param.GetName(): paramMap[param.GetName()]}) + sqlParams = append(sqlParams, map[string]any{param.GetName(): paramMap[param.GetName()]}) } - - body, err := json.Marshal(bodyStruct) - if err != nil { - return nil, fmt.Errorf("failed to marshal query body: %w", err) - } - res, err := esapi.EsqlQueryRequest{ - Body: bytes.NewReader(body), - Format: t.Format, - FilterPath: []string{"columns", "values"}, - Instrument: source.ElasticsearchClient().InstrumentationEnabled(), - }.Do(ctx, source.ElasticsearchClient()) - - if err != nil { - return nil, err - } - defer res.Body.Close() - - if res.IsError() { - // Try to extract error message from response - var esErr json.RawMessage - err = util.DecodeJSON(res.Body, &esErr) - if err != nil { - return nil, fmt.Errorf("elasticsearch error: status %s", res.Status()) - } - return esErr, nil - } - - var result esqlResult - err = util.DecodeJSON(res.Body, &result) - if err != nil { - return nil, fmt.Errorf("failed to decode response body: %w", err) - } - - output := t.esqlToMap(result) - - return output, nil -} - -// esqlToMap converts the esqlResult to a slice of maps. -func (t Tool) esqlToMap(result esqlResult) []map[string]any { - output := make([]map[string]any, 0, len(result.Values)) - for _, value := range result.Values { - row := make(map[string]any) - if value == nil { - output = append(output, row) - continue - } - for i, col := range result.Columns { - if i < len(value) { - row[col.Name] = value[i] - } else { - row[col.Name] = nil - } - } - output = append(output, row) - } - return output + return source.RunSQL(ctx, t.Format, query, sqlParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql_test.go b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql_test.go index 2382483429..ec65d2842a 100644 --- a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql_test.go +++ b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql_test.go @@ -15,7 +15,6 @@ package elasticsearchesql import ( - "reflect" "testing" "github.com/goccy/go-yaml" @@ -106,156 +105,3 @@ func TestParseFromYamlElasticsearchEsql(t *testing.T) { }) } } - -func TestTool_esqlToMap(t1 *testing.T) { - tests := []struct { - name string - result esqlResult - want []map[string]any - }{ - { - name: "simple case with two rows", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "first_name", Type: "text"}, - {Name: "last_name", Type: "text"}, - }, - Values: [][]any{ - {"John", "Doe"}, - {"Jane", "Smith"}, - }, - }, - want: []map[string]any{ - {"first_name": "John", "last_name": "Doe"}, - {"first_name": "Jane", "last_name": "Smith"}, - }, - }, - { - name: "different data types", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "id", Type: "integer"}, - {Name: "active", Type: "boolean"}, - {Name: "score", Type: "float"}, - }, - Values: [][]any{ - {1, true, 95.5}, - {2, false, 88.0}, - }, - }, - want: []map[string]any{ - {"id": 1, "active": true, "score": 95.5}, - {"id": 2, "active": false, "score": 88.0}, - }, - }, - { - name: "no rows", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "id", Type: "integer"}, - {Name: "name", Type: "text"}, - }, - Values: [][]any{}, - }, - want: []map[string]any{}, - }, - { - name: "null values", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "id", Type: "integer"}, - {Name: "name", Type: "text"}, - }, - Values: [][]any{ - {1, nil}, - {2, "Alice"}, - }, - }, - want: []map[string]any{ - {"id": 1, "name": nil}, - {"id": 2, "name": "Alice"}, - }, - }, - { - name: "missing values in a row", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "id", Type: "integer"}, - {Name: "name", Type: "text"}, - {Name: "age", Type: "integer"}, - }, - Values: [][]any{ - {1, "Bob"}, - {2, "Charlie", 30}, - }, - }, - want: []map[string]any{ - {"id": 1, "name": "Bob", "age": nil}, - {"id": 2, "name": "Charlie", "age": 30}, - }, - }, - { - name: "all null row", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "id", Type: "integer"}, - {Name: "name", Type: "text"}, - }, - Values: [][]any{ - nil, - }, - }, - want: []map[string]any{ - {}, - }, - }, - { - name: "empty columns", - result: esqlResult{ - Columns: []esqlColumn{}, - Values: [][]any{ - {}, - {}, - }, - }, - want: []map[string]any{ - {}, - {}, - }, - }, - { - name: "more values than columns", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "id", Type: "integer"}, - }, - Values: [][]any{ - {1, "extra"}, - }, - }, - want: []map[string]any{ - {"id": 1}, - }, - }, - { - name: "no columns but with values", - result: esqlResult{ - Columns: []esqlColumn{}, - Values: [][]any{ - {1, "data"}, - }, - }, - want: []map[string]any{ - {}, - }, - }, - } - for _, tt := range tests { - t1.Run(tt.name, func(t1 *testing.T) { - t := Tool{} - if got := t.esqlToMap(tt.result); !reflect.DeepEqual(got, tt.want) { - t1.Errorf("esqlToMap() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go b/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go index 28c8d0fb63..a6f6c01979 100644 --- a/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go +++ b/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go @@ -44,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { FirebirdDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -106,49 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - - rows, err := source.FirebirdDB().QueryContext(ctx, sql) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer rows.Close() - - cols, err := rows.Columns() - - var out []any - if err == nil && len(cols) > 0 { - values := make([]any, len(cols)) - scanArgs := make([]any, len(values)) - for i := range values { - scanArgs[i] = &values[i] - } - - for rows.Next() { - err = rows.Scan(scanArgs...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - - vMap := make(map[string]any) - for i, colName := range cols { - if b, ok := values[i].([]byte); ok { - vMap[colName] = string(b) - } else { - vMap[colName] = values[i] - } - } - out = append(out, vMap) - } - } - - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating rows: %w", err) - } - - // In most cases, DML/DDL statements like INSERT, UPDATE, CREATE, etc. might return no rows - // However, it is also possible that this was a query that was expected to return rows - // but returned none, a case that we cannot distinguish here. - return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/firebird/firebirdsql/firebirdsql.go b/internal/tools/firebird/firebirdsql/firebirdsql.go index 9dd040dcd7..74912714a7 100644 --- a/internal/tools/firebird/firebirdsql/firebirdsql.go +++ b/internal/tools/firebird/firebirdsql/firebirdsql.go @@ -44,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { FirebirdDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -125,51 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para namedArgs = append(namedArgs, value) } } - - rows, err := source.FirebirdDB().QueryContext(ctx, statement, namedArgs...) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer rows.Close() - - cols, err := rows.Columns() - if err != nil { - return nil, fmt.Errorf("unable to get columns: %w", err) - } - - values := make([]any, len(cols)) - scanArgs := make([]any, len(values)) - for i := range values { - scanArgs[i] = &values[i] - } - - var out []any - for rows.Next() { - - err = rows.Scan(scanArgs...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - - vMap := make(map[string]any) - for i, col := range cols { - if b, ok := values[i].([]byte); ok { - vMap[col] = string(b) - } else { - vMap[col] = values[i] - } - } - out = append(out, vMap) - } - - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating rows: %w", err) - } - - // In most cases, DML/DDL statements like INSERT, UPDATE, CREATE, etc. might return no rows - // However, it is also possible that this was a query that was expected to return rows - // but returned none, a case that we cannot distinguish here. - return out, nil + return source.RunSQL(ctx, statement, namedArgs) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {