mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-08 23:18:04 -05:00
refactor: move source implementation in Invoke() function to Source (#2234)
Move source-related queries from `Invoke()` function into Source. The following sources are updated in this PR: * couchbase * dgraph * elasticsearch * firebird This is an effort to generalizing tools to work with any Source that implements a specific interface. This will provide a better segregation of the roles for Tools vs Source. Tool's role will be limited to the following: * Resolve any pre-implementation steps or parameters (e.g. template parameters) * Retrieving Source * Calling the source's implementation
This commit is contained in:
@@ -17,6 +17,7 @@ package couchbase
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
@@ -24,6 +25,7 @@ import (
|
|||||||
tlsutil "github.com/couchbase/tools-common/http/tls"
|
tlsutil "github.com/couchbase/tools-common/http/tls"
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -110,6 +112,27 @@ func (s *Source) CouchbaseQueryScanConsistency() uint {
|
|||||||
return s.QueryScanConsistency
|
return s.QueryScanConsistency
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Source) RunSQL(statement string, params parameters.ParamValues) (any, error) {
|
||||||
|
results, err := s.CouchbaseScope().Query(statement, &gocb.QueryOptions{
|
||||||
|
ScanConsistency: gocb.QueryScanConsistency(s.CouchbaseQueryScanConsistency()),
|
||||||
|
NamedParameters: params.AsMap(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out []any
|
||||||
|
for results.Next() {
|
||||||
|
var result json.RawMessage
|
||||||
|
err := results.Row(&result)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error processing row: %w", err)
|
||||||
|
}
|
||||||
|
out = append(out, result)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r Config) createCouchbaseOptions() (gocb.ClusterOptions, error) {
|
func (r Config) createCouchbaseOptions() (gocb.ClusterOptions, error) {
|
||||||
cbOpts := gocb.ClusterOptions{}
|
cbOpts := gocb.ClusterOptions{}
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
|
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -114,6 +115,28 @@ func (s *Source) DgraphClient() *DgraphClient {
|
|||||||
return s.Client
|
return s.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Source) RunSQL(statement string, params parameters.ParamValues, isQuery bool, timeout string) (any, error) {
|
||||||
|
paramsMap := params.AsMapWithDollarPrefix()
|
||||||
|
resp, err := s.DgraphClient().ExecuteQuery(statement, paramsMap, isQuery, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := checkError(resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Data map[string]interface{} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(resp, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (*DgraphClient, error) {
|
func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (*DgraphClient, error) {
|
||||||
//nolint:all // Reassigned ctx
|
//nolint:all // Reassigned ctx
|
||||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, r.Name)
|
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, r.Name)
|
||||||
@@ -285,7 +308,7 @@ func (hc *DgraphClient) doLogin(creds map[string]interface{}) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := CheckError(resp); err != nil {
|
if err := checkError(resp); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,7 +393,7 @@ func getUrl(baseUrl, resource string, params url.Values) (string, error) {
|
|||||||
return u.String(), nil
|
return u.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckError(resp []byte) error {
|
func checkError(resp []byte) error {
|
||||||
var errResp struct {
|
var errResp struct {
|
||||||
Errors []struct {
|
Errors []struct {
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
|
|||||||
@@ -15,7 +15,9 @@
|
|||||||
package elasticsearch
|
package elasticsearch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
@@ -149,3 +151,80 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
|||||||
func (s *Source) ElasticsearchClient() EsClient {
|
func (s *Source) ElasticsearchClient() EsClient {
|
||||||
return s.Client
|
return s.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EsqlColumn struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EsqlResult struct {
|
||||||
|
Columns []EsqlColumn `json:"columns"`
|
||||||
|
Values [][]any `json:"values"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) RunSQL(ctx context.Context, format, query string, params []map[string]any) (any, error) {
|
||||||
|
bodyStruct := struct {
|
||||||
|
Query string `json:"query"`
|
||||||
|
Params []map[string]any `json:"params,omitempty"`
|
||||||
|
}{
|
||||||
|
Query: query,
|
||||||
|
Params: params,
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(bodyStruct)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal query body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := esapi.EsqlQueryRequest{
|
||||||
|
Body: bytes.NewReader(body),
|
||||||
|
Format: format,
|
||||||
|
FilterPath: []string{"columns", "values"},
|
||||||
|
Instrument: s.ElasticsearchClient().InstrumentationEnabled(),
|
||||||
|
}.Do(ctx, s.ElasticsearchClient())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.IsError() {
|
||||||
|
// Try to extract error message from response
|
||||||
|
var esErr json.RawMessage
|
||||||
|
err = util.DecodeJSON(res.Body, &esErr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("elasticsearch error: status %s", res.Status())
|
||||||
|
}
|
||||||
|
return esErr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result EsqlResult
|
||||||
|
err = util.DecodeJSON(res.Body, &result)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := EsqlToMap(result)
|
||||||
|
|
||||||
|
return output, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EsqlToMap converts the esqlResult to a slice of maps.
|
||||||
|
func EsqlToMap(result EsqlResult) []map[string]any {
|
||||||
|
output := make([]map[string]any, 0, len(result.Values))
|
||||||
|
for _, value := range result.Values {
|
||||||
|
row := make(map[string]any)
|
||||||
|
if value == nil {
|
||||||
|
output = append(output, row)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for i, col := range result.Columns {
|
||||||
|
if i < len(value) {
|
||||||
|
row[col.Name] = value[i]
|
||||||
|
} else {
|
||||||
|
row[col.Name] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output = append(output, row)
|
||||||
|
}
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
package elasticsearch_test
|
package elasticsearch_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
@@ -64,3 +65,155 @@ func TestParseFromYamlElasticsearch(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTool_esqlToMap(t1 *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
result elasticsearch.EsqlResult
|
||||||
|
want []map[string]any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple case with two rows",
|
||||||
|
result: elasticsearch.EsqlResult{
|
||||||
|
Columns: []elasticsearch.EsqlColumn{
|
||||||
|
{Name: "first_name", Type: "text"},
|
||||||
|
{Name: "last_name", Type: "text"},
|
||||||
|
},
|
||||||
|
Values: [][]any{
|
||||||
|
{"John", "Doe"},
|
||||||
|
{"Jane", "Smith"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"first_name": "John", "last_name": "Doe"},
|
||||||
|
{"first_name": "Jane", "last_name": "Smith"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different data types",
|
||||||
|
result: elasticsearch.EsqlResult{
|
||||||
|
Columns: []elasticsearch.EsqlColumn{
|
||||||
|
{Name: "id", Type: "integer"},
|
||||||
|
{Name: "active", Type: "boolean"},
|
||||||
|
{Name: "score", Type: "float"},
|
||||||
|
},
|
||||||
|
Values: [][]any{
|
||||||
|
{1, true, 95.5},
|
||||||
|
{2, false, 88.0},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"id": 1, "active": true, "score": 95.5},
|
||||||
|
{"id": 2, "active": false, "score": 88.0},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no rows",
|
||||||
|
result: elasticsearch.EsqlResult{
|
||||||
|
Columns: []elasticsearch.EsqlColumn{
|
||||||
|
{Name: "id", Type: "integer"},
|
||||||
|
{Name: "name", Type: "text"},
|
||||||
|
},
|
||||||
|
Values: [][]any{},
|
||||||
|
},
|
||||||
|
want: []map[string]any{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "null values",
|
||||||
|
result: elasticsearch.EsqlResult{
|
||||||
|
Columns: []elasticsearch.EsqlColumn{
|
||||||
|
{Name: "id", Type: "integer"},
|
||||||
|
{Name: "name", Type: "text"},
|
||||||
|
},
|
||||||
|
Values: [][]any{
|
||||||
|
{1, nil},
|
||||||
|
{2, "Alice"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"id": 1, "name": nil},
|
||||||
|
{"id": 2, "name": "Alice"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing values in a row",
|
||||||
|
result: elasticsearch.EsqlResult{
|
||||||
|
Columns: []elasticsearch.EsqlColumn{
|
||||||
|
{Name: "id", Type: "integer"},
|
||||||
|
{Name: "name", Type: "text"},
|
||||||
|
{Name: "age", Type: "integer"},
|
||||||
|
},
|
||||||
|
Values: [][]any{
|
||||||
|
{1, "Bob"},
|
||||||
|
{2, "Charlie", 30},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"id": 1, "name": "Bob", "age": nil},
|
||||||
|
{"id": 2, "name": "Charlie", "age": 30},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all null row",
|
||||||
|
result: elasticsearch.EsqlResult{
|
||||||
|
Columns: []elasticsearch.EsqlColumn{
|
||||||
|
{Name: "id", Type: "integer"},
|
||||||
|
{Name: "name", Type: "text"},
|
||||||
|
},
|
||||||
|
Values: [][]any{
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty columns",
|
||||||
|
result: elasticsearch.EsqlResult{
|
||||||
|
Columns: []elasticsearch.EsqlColumn{},
|
||||||
|
Values: [][]any{
|
||||||
|
{},
|
||||||
|
{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{},
|
||||||
|
{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "more values than columns",
|
||||||
|
result: elasticsearch.EsqlResult{
|
||||||
|
Columns: []elasticsearch.EsqlColumn{
|
||||||
|
{Name: "id", Type: "integer"},
|
||||||
|
},
|
||||||
|
Values: [][]any{
|
||||||
|
{1, "extra"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"id": 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no columns but with values",
|
||||||
|
result: elasticsearch.EsqlResult{
|
||||||
|
Columns: []elasticsearch.EsqlColumn{},
|
||||||
|
Values: [][]any{
|
||||||
|
{1, "data"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t1.Run(tt.name, func(t1 *testing.T) {
|
||||||
|
if got := elasticsearch.EsqlToMap(tt.result); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t1.Errorf("esqlToMap() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -96,6 +96,53 @@ func (s *Source) FirebirdDB() *sql.DB {
|
|||||||
return s.Db
|
return s.Db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
|
||||||
|
rows, err := s.FirebirdDB().QueryContext(ctx, statement, params...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
cols, err := rows.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to get columns: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
values := make([]any, len(cols))
|
||||||
|
scanArgs := make([]any, len(values))
|
||||||
|
for i := range values {
|
||||||
|
scanArgs[i] = &values[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
var out []any
|
||||||
|
for rows.Next() {
|
||||||
|
|
||||||
|
err = rows.Scan(scanArgs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
vMap := make(map[string]any)
|
||||||
|
for i, col := range cols {
|
||||||
|
if b, ok := values[i].([]byte); ok {
|
||||||
|
vMap[col] = string(b)
|
||||||
|
} else {
|
||||||
|
vMap[col] = values[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, vMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// In most cases, DML/DDL statements like INSERT, UPDATE, CREATE, etc. might return no rows
|
||||||
|
// However, it is also possible that this was a query that was expected to return rows
|
||||||
|
// but returned none, a case that we cannot distinguish here.
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
func initFirebirdConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*sql.DB, error) {
|
func initFirebirdConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*sql.DB, error) {
|
||||||
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ package couchbase
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/couchbase/gocb/v2"
|
"github.com/couchbase/gocb/v2"
|
||||||
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
CouchbaseScope() *gocb.Scope
|
CouchbaseScope() *gocb.Scope
|
||||||
CouchbaseQueryScanConsistency() uint
|
RunSQL(string, parameters.ParamValues) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -112,24 +111,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||||
}
|
}
|
||||||
results, err := source.CouchbaseScope().Query(newStatement, &gocb.QueryOptions{
|
return source.RunSQL(newStatement, newParams)
|
||||||
ScanConsistency: gocb.QueryScanConsistency(source.CouchbaseQueryScanConsistency()),
|
|
||||||
NamedParameters: newParams.AsMap(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
for results.Next() {
|
|
||||||
var result json.RawMessage
|
|
||||||
err := results.Row(&result)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error processing row: %w", err)
|
|
||||||
}
|
|
||||||
out = append(out, result)
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ package dgraph
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
@@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
DgraphClient() *dgraph.DgraphClient
|
DgraphClient() *dgraph.DgraphClient
|
||||||
|
RunSQL(string, parameters.ParamValues, bool, string) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -95,27 +95,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return source.RunSQL(t.Statement, params, t.IsQuery, t.Timeout)
|
||||||
paramsMap := params.AsMapWithDollarPrefix()
|
|
||||||
|
|
||||||
resp, err := source.DgraphClient().ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := dgraph.CheckError(resp); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var result struct {
|
|
||||||
Data map[string]interface{} `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(resp, &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing JSON: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.Data, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -15,14 +15,10 @@
|
|||||||
package elasticsearchesql
|
package elasticsearchesql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/elastic/go-elasticsearch/v9/esapi"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
|
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
@@ -41,6 +37,7 @@ func init() {
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
ElasticsearchClient() es.EsClient
|
ElasticsearchClient() es.EsClient
|
||||||
|
RunSQL(ctx context.Context, format, query string, params []map[string]any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -91,16 +88,6 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
type esqlColumn struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type esqlResult struct {
|
|
||||||
Columns []esqlColumn `json:"columns"`
|
|
||||||
Values [][]any `json:"values"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -116,20 +103,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
bodyStruct := struct {
|
query := t.Query
|
||||||
Query string `json:"query"`
|
sqlParams := make([]map[string]any, 0, len(params))
|
||||||
Params []map[string]any `json:"params,omitempty"`
|
|
||||||
}{
|
|
||||||
Query: t.Query,
|
|
||||||
Params: make([]map[string]any, 0, len(params)),
|
|
||||||
}
|
|
||||||
|
|
||||||
paramMap := params.AsMap()
|
paramMap := params.AsMap()
|
||||||
|
|
||||||
// If a query is provided in the params and not already set in the tool, use it.
|
// If a query is provided in the params and not already set in the tool, use it.
|
||||||
if query, ok := paramMap["query"]; ok {
|
if queryVal, ok := paramMap["query"]; ok {
|
||||||
if str, ok := query.(string); ok && bodyStruct.Query == "" {
|
if str, ok := queryVal.(string); ok && t.Query == "" {
|
||||||
bodyStruct.Query = str
|
query = str
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drop the query param if not a string or if the tool already has a query.
|
// Drop the query param if not a string or if the tool already has a query.
|
||||||
@@ -140,65 +120,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if param.GetType() == "array" {
|
if param.GetType() == "array" {
|
||||||
return nil, fmt.Errorf("array parameters are not supported yet")
|
return nil, fmt.Errorf("array parameters are not supported yet")
|
||||||
}
|
}
|
||||||
bodyStruct.Params = append(bodyStruct.Params, map[string]any{param.GetName(): paramMap[param.GetName()]})
|
sqlParams = append(sqlParams, map[string]any{param.GetName(): paramMap[param.GetName()]})
|
||||||
}
|
}
|
||||||
|
return source.RunSQL(ctx, t.Format, query, sqlParams)
|
||||||
body, err := json.Marshal(bodyStruct)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to marshal query body: %w", err)
|
|
||||||
}
|
|
||||||
res, err := esapi.EsqlQueryRequest{
|
|
||||||
Body: bytes.NewReader(body),
|
|
||||||
Format: t.Format,
|
|
||||||
FilterPath: []string{"columns", "values"},
|
|
||||||
Instrument: source.ElasticsearchClient().InstrumentationEnabled(),
|
|
||||||
}.Do(ctx, source.ElasticsearchClient())
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
|
|
||||||
if res.IsError() {
|
|
||||||
// Try to extract error message from response
|
|
||||||
var esErr json.RawMessage
|
|
||||||
err = util.DecodeJSON(res.Body, &esErr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("elasticsearch error: status %s", res.Status())
|
|
||||||
}
|
|
||||||
return esErr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var result esqlResult
|
|
||||||
err = util.DecodeJSON(res.Body, &result)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to decode response body: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
output := t.esqlToMap(result)
|
|
||||||
|
|
||||||
return output, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// esqlToMap converts the esqlResult to a slice of maps.
|
|
||||||
func (t Tool) esqlToMap(result esqlResult) []map[string]any {
|
|
||||||
output := make([]map[string]any, 0, len(result.Values))
|
|
||||||
for _, value := range result.Values {
|
|
||||||
row := make(map[string]any)
|
|
||||||
if value == nil {
|
|
||||||
output = append(output, row)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for i, col := range result.Columns {
|
|
||||||
if i < len(value) {
|
|
||||||
row[col.Name] = value[i]
|
|
||||||
} else {
|
|
||||||
row[col.Name] = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
output = append(output, row)
|
|
||||||
}
|
|
||||||
return output
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
package elasticsearchesql
|
package elasticsearchesql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
@@ -106,156 +105,3 @@ func TestParseFromYamlElasticsearchEsql(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTool_esqlToMap(t1 *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
result esqlResult
|
|
||||||
want []map[string]any
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "simple case with two rows",
|
|
||||||
result: esqlResult{
|
|
||||||
Columns: []esqlColumn{
|
|
||||||
{Name: "first_name", Type: "text"},
|
|
||||||
{Name: "last_name", Type: "text"},
|
|
||||||
},
|
|
||||||
Values: [][]any{
|
|
||||||
{"John", "Doe"},
|
|
||||||
{"Jane", "Smith"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: []map[string]any{
|
|
||||||
{"first_name": "John", "last_name": "Doe"},
|
|
||||||
{"first_name": "Jane", "last_name": "Smith"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "different data types",
|
|
||||||
result: esqlResult{
|
|
||||||
Columns: []esqlColumn{
|
|
||||||
{Name: "id", Type: "integer"},
|
|
||||||
{Name: "active", Type: "boolean"},
|
|
||||||
{Name: "score", Type: "float"},
|
|
||||||
},
|
|
||||||
Values: [][]any{
|
|
||||||
{1, true, 95.5},
|
|
||||||
{2, false, 88.0},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: []map[string]any{
|
|
||||||
{"id": 1, "active": true, "score": 95.5},
|
|
||||||
{"id": 2, "active": false, "score": 88.0},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no rows",
|
|
||||||
result: esqlResult{
|
|
||||||
Columns: []esqlColumn{
|
|
||||||
{Name: "id", Type: "integer"},
|
|
||||||
{Name: "name", Type: "text"},
|
|
||||||
},
|
|
||||||
Values: [][]any{},
|
|
||||||
},
|
|
||||||
want: []map[string]any{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "null values",
|
|
||||||
result: esqlResult{
|
|
||||||
Columns: []esqlColumn{
|
|
||||||
{Name: "id", Type: "integer"},
|
|
||||||
{Name: "name", Type: "text"},
|
|
||||||
},
|
|
||||||
Values: [][]any{
|
|
||||||
{1, nil},
|
|
||||||
{2, "Alice"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: []map[string]any{
|
|
||||||
{"id": 1, "name": nil},
|
|
||||||
{"id": 2, "name": "Alice"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "missing values in a row",
|
|
||||||
result: esqlResult{
|
|
||||||
Columns: []esqlColumn{
|
|
||||||
{Name: "id", Type: "integer"},
|
|
||||||
{Name: "name", Type: "text"},
|
|
||||||
{Name: "age", Type: "integer"},
|
|
||||||
},
|
|
||||||
Values: [][]any{
|
|
||||||
{1, "Bob"},
|
|
||||||
{2, "Charlie", 30},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: []map[string]any{
|
|
||||||
{"id": 1, "name": "Bob", "age": nil},
|
|
||||||
{"id": 2, "name": "Charlie", "age": 30},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "all null row",
|
|
||||||
result: esqlResult{
|
|
||||||
Columns: []esqlColumn{
|
|
||||||
{Name: "id", Type: "integer"},
|
|
||||||
{Name: "name", Type: "text"},
|
|
||||||
},
|
|
||||||
Values: [][]any{
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: []map[string]any{
|
|
||||||
{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty columns",
|
|
||||||
result: esqlResult{
|
|
||||||
Columns: []esqlColumn{},
|
|
||||||
Values: [][]any{
|
|
||||||
{},
|
|
||||||
{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: []map[string]any{
|
|
||||||
{},
|
|
||||||
{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "more values than columns",
|
|
||||||
result: esqlResult{
|
|
||||||
Columns: []esqlColumn{
|
|
||||||
{Name: "id", Type: "integer"},
|
|
||||||
},
|
|
||||||
Values: [][]any{
|
|
||||||
{1, "extra"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: []map[string]any{
|
|
||||||
{"id": 1},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no columns but with values",
|
|
||||||
result: esqlResult{
|
|
||||||
Columns: []esqlColumn{},
|
|
||||||
Values: [][]any{
|
|
||||||
{1, "data"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: []map[string]any{
|
|
||||||
{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t1.Run(tt.name, func(t1 *testing.T) {
|
|
||||||
t := Tool{}
|
|
||||||
if got := t.esqlToMap(tt.result); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t1.Errorf("esqlToMap() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
FirebirdDB() *sql.DB
|
FirebirdDB() *sql.DB
|
||||||
|
RunSQL(context.Context, string, []any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -106,49 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
|
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
|
||||||
|
return source.RunSQL(ctx, sql, nil)
|
||||||
rows, err := source.FirebirdDB().QueryContext(ctx, sql)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
cols, err := rows.Columns()
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
if err == nil && len(cols) > 0 {
|
|
||||||
values := make([]any, len(cols))
|
|
||||||
scanArgs := make([]any, len(values))
|
|
||||||
for i := range values {
|
|
||||||
scanArgs[i] = &values[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
err = rows.Scan(scanArgs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for i, colName := range cols {
|
|
||||||
if b, ok := values[i].([]byte); ok {
|
|
||||||
vMap[colName] = string(b)
|
|
||||||
} else {
|
|
||||||
vMap[colName] = values[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("error iterating rows: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// In most cases, DML/DDL statements like INSERT, UPDATE, CREATE, etc. might return no rows
|
|
||||||
// However, it is also possible that this was a query that was expected to return rows
|
|
||||||
// but returned none, a case that we cannot distinguish here.
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
FirebirdDB() *sql.DB
|
FirebirdDB() *sql.DB
|
||||||
|
RunSQL(context.Context, string, []any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -125,51 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
namedArgs = append(namedArgs, value)
|
namedArgs = append(namedArgs, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return source.RunSQL(ctx, statement, namedArgs)
|
||||||
rows, err := source.FirebirdDB().QueryContext(ctx, statement, namedArgs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
cols, err := rows.Columns()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to get columns: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
values := make([]any, len(cols))
|
|
||||||
scanArgs := make([]any, len(values))
|
|
||||||
for i := range values {
|
|
||||||
scanArgs[i] = &values[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
for rows.Next() {
|
|
||||||
|
|
||||||
err = rows.Scan(scanArgs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for i, col := range cols {
|
|
||||||
if b, ok := values[i].([]byte); ok {
|
|
||||||
vMap[col] = string(b)
|
|
||||||
} else {
|
|
||||||
vMap[col] = values[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("error iterating rows: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// In most cases, DML/DDL statements like INSERT, UPDATE, CREATE, etc. might return no rows
|
|
||||||
// However, it is also possible that this was a query that was expected to return rows
|
|
||||||
// but returned none, a case that we cannot distinguish here.
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user