refactor: move source implementation in Invoke() function to Source (#2229)

Move source-related queries from `Invoke()` function into Source.

The following sources were updated in this PR:
* bigtable
* cassandra
* clickhouse
* cloud gemini data analytics

This is an effort to generalizing tools to work with any Source that
implements a specific interface. This will provide a better segregation
of the roles for Tools vs Source.

Tool's role will be limited to the following:
* Resolve any pre-implementation steps or parameters (e.g. template
parameters)
* Retrieving Source
* Calling the source's implementation


This PR also fix the following gemini review recommendations:
* Bigtable `resultRow.GetByName()` to throw an error and return false
* Clickhouselistdatabases and Clickhouselisttables to reuse the
`RunSQL()` function
This commit is contained in:
Yuan Teoh
2025-12-30 13:55:02 -08:00
committed by GitHub
parent 38d127a354
commit 285cdcd69a
11 changed files with 257 additions and 314 deletions

View File

@@ -22,6 +22,7 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"go.opentelemetry.io/otel/trace"
"google.golang.org/api/option"
)
@@ -88,6 +89,94 @@ func (s *Source) BigtableClient() *bigtable.Client {
return s.Client
}
func getBigtableType(paramType string) (bigtable.SQLType, error) {
switch paramType {
case "boolean":
return bigtable.BoolSQLType{}, nil
case "string":
return bigtable.StringSQLType{}, nil
case "integer":
return bigtable.Int64SQLType{}, nil
case "float":
return bigtable.Float64SQLType{}, nil
case "array":
return bigtable.ArraySQLType{}, nil
default:
return nil, fmt.Errorf("unknow param type %s", paramType)
}
}
func getMapParamsType(tparams parameters.Parameters) (map[string]bigtable.SQLType, error) {
btParamTypes := make(map[string]bigtable.SQLType)
for _, p := range tparams {
if p.GetType() == "array" {
itemType, err := getBigtableType(p.Manifest().Items.Type)
if err != nil {
return nil, err
}
btParamTypes[p.GetName()] = bigtable.ArraySQLType{
ElemType: itemType,
}
continue
}
paramType, err := getBigtableType(p.GetType())
if err != nil {
return nil, err
}
btParamTypes[p.GetName()] = paramType
}
return btParamTypes, nil
}
func (s *Source) RunSQL(ctx context.Context, statement string, configParam parameters.Parameters, params parameters.ParamValues) (any, error) {
mapParamsType, err := getMapParamsType(configParam)
if err != nil {
return nil, fmt.Errorf("fail to get map params: %w", err)
}
ps, err := s.BigtableClient().PrepareStatement(
ctx,
statement,
mapParamsType,
)
if err != nil {
return nil, fmt.Errorf("unable to prepare statement: %w", err)
}
bs, err := ps.Bind(params.AsMap())
if err != nil {
return nil, fmt.Errorf("unable to bind: %w", err)
}
var out []any
var rowErr error
err = bs.Execute(ctx, func(resultRow bigtable.ResultRow) bool {
vMap := make(map[string]any)
cols := resultRow.Metadata.Columns
for _, c := range cols {
var columValue any
if err = resultRow.GetByName(c.Name, &columValue); err != nil {
rowErr = err
return false
}
vMap[c.Name] = columValue
}
out = append(out, vMap)
return true
})
if err != nil {
return nil, fmt.Errorf("unable to execute client: %w", err)
}
if rowErr != nil {
return nil, fmt.Errorf("error processing row: %w", rowErr)
}
return out, nil
}
func initBigtableClient(ctx context.Context, tracer trace.Tracer, name, project, instance string) (*bigtable.Client, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)

View File

@@ -21,6 +21,7 @@ import (
gocql "github.com/apache/cassandra-gocql-driver/v2"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"go.opentelemetry.io/otel/trace"
)
@@ -89,10 +90,32 @@ func (s *Source) ToConfig() sources.SourceConfig {
}
// SourceKind implements sources.Source.
func (s Source) SourceKind() string {
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) RunSQL(ctx context.Context, statement string, params parameters.ParamValues) (any, error) {
sliceParams := params.AsSlice()
iter := s.CassandraSession().Query(statement, sliceParams...).IterContext(ctx)
// Create a slice to store the out
var out []map[string]interface{}
// Scan results into a map and append to the slice
for {
row := make(map[string]interface{}) // Create a new map for each row
if !iter.MapScan(row) {
break // No more rows
}
out = append(out, row)
}
if err := iter.Close(); err != nil {
return nil, fmt.Errorf("unable to parse rows: %w", err)
}
return out, nil
}
var _ sources.Source = &Source{}
func initCassandraSession(ctx context.Context, tracer trace.Tracer, c Config) (*gocql.Session, error) {

View File

@@ -24,6 +24,7 @@ import (
_ "github.com/ClickHouse/clickhouse-go/v2"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"go.opentelemetry.io/otel/trace"
)
@@ -99,6 +100,69 @@ func (s *Source) ClickHousePool() *sql.DB {
return s.Pool
}
func (s *Source) RunSQL(ctx context.Context, statement string, params parameters.ParamValues) (any, error) {
var sliceParams []any
if params != nil {
sliceParams = params.AsSlice()
}
results, err := s.ClickHousePool().QueryContext(ctx, statement, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer results.Close()
cols, err := results.Columns()
if err != nil {
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
}
// create an array of values for each column, which can be re-used to scan each row
rawValues := make([]any, len(cols))
values := make([]any, len(cols))
for i := range rawValues {
values[i] = &rawValues[i]
}
colTypes, err := results.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("unable to get column types: %w", err)
}
var out []any
for results.Next() {
err := results.Scan(values...)
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, name := range cols {
// ClickHouse driver may return specific types that need handling
switch colTypes[i].DatabaseTypeName() {
case "String", "FixedString":
if rawValues[i] != nil {
// Handle potential []byte to string conversion if needed
if b, ok := rawValues[i].([]byte); ok {
vMap[name] = string(b)
} else {
vMap[name] = rawValues[i]
}
} else {
vMap[name] = nil
}
default:
vMap[name] = rawValues[i]
}
}
out = append(out, vMap)
}
if err := results.Err(); err != nil {
return nil, fmt.Errorf("errors encountered by results.Scan: %w", err)
}
return out, nil
}
func validateConfig(protocol string) error {
validProtocols := map[string]bool{"http": true, "https": true}

View File

@@ -14,8 +14,11 @@
package cloudgda
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/goccy/go-yaml"
@@ -131,3 +134,43 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}
func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) {
// The API endpoint itself always uses the "global" location.
apiLocation := "global"
apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent)
client, err := s.GetClient(ctx, tokenStr)
if err != nil {
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
return result, nil
}