mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-05-02 03:00:36 -04:00
refactor: move source implementation in Invoke() function to Source (#2229)
Move source-related queries from `Invoke()` function into Source. The following sources were updated in this PR: * bigtable * cassandra * clickhouse * cloud gemini data analytics This is an effort to generalizing tools to work with any Source that implements a specific interface. This will provide a better segregation of the roles for Tools vs Source. Tool's role will be limited to the following: * Resolve any pre-implementation steps or parameters (e.g. template parameters) * Retrieving Source * Calling the source's implementation This PR also fix the following gemini review recommendations: * Bigtable `resultRow.GetByName()` to throw an error and return false * Clickhouselistdatabases and Clickhouselisttables to reuse the `RunSQL()` function
This commit is contained in:
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
@@ -88,6 +89,94 @@ func (s *Source) BigtableClient() *bigtable.Client {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func getBigtableType(paramType string) (bigtable.SQLType, error) {
|
||||
switch paramType {
|
||||
case "boolean":
|
||||
return bigtable.BoolSQLType{}, nil
|
||||
case "string":
|
||||
return bigtable.StringSQLType{}, nil
|
||||
case "integer":
|
||||
return bigtable.Int64SQLType{}, nil
|
||||
case "float":
|
||||
return bigtable.Float64SQLType{}, nil
|
||||
case "array":
|
||||
return bigtable.ArraySQLType{}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknow param type %s", paramType)
|
||||
}
|
||||
}
|
||||
|
||||
func getMapParamsType(tparams parameters.Parameters) (map[string]bigtable.SQLType, error) {
|
||||
btParamTypes := make(map[string]bigtable.SQLType)
|
||||
for _, p := range tparams {
|
||||
if p.GetType() == "array" {
|
||||
itemType, err := getBigtableType(p.Manifest().Items.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
btParamTypes[p.GetName()] = bigtable.ArraySQLType{
|
||||
ElemType: itemType,
|
||||
}
|
||||
continue
|
||||
}
|
||||
paramType, err := getBigtableType(p.GetType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
btParamTypes[p.GetName()] = paramType
|
||||
}
|
||||
return btParamTypes, nil
|
||||
}
|
||||
|
||||
func (s *Source) RunSQL(ctx context.Context, statement string, configParam parameters.Parameters, params parameters.ParamValues) (any, error) {
|
||||
mapParamsType, err := getMapParamsType(configParam)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fail to get map params: %w", err)
|
||||
}
|
||||
|
||||
ps, err := s.BigtableClient().PrepareStatement(
|
||||
ctx,
|
||||
statement,
|
||||
mapParamsType,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to prepare statement: %w", err)
|
||||
}
|
||||
|
||||
bs, err := ps.Bind(params.AsMap())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to bind: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
var rowErr error
|
||||
err = bs.Execute(ctx, func(resultRow bigtable.ResultRow) bool {
|
||||
vMap := make(map[string]any)
|
||||
cols := resultRow.Metadata.Columns
|
||||
|
||||
for _, c := range cols {
|
||||
var columValue any
|
||||
if err = resultRow.GetByName(c.Name, &columValue); err != nil {
|
||||
rowErr = err
|
||||
return false
|
||||
}
|
||||
vMap[c.Name] = columValue
|
||||
}
|
||||
|
||||
out = append(out, vMap)
|
||||
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute client: %w", err)
|
||||
}
|
||||
if rowErr != nil {
|
||||
return nil, fmt.Errorf("error processing row: %w", rowErr)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func initBigtableClient(ctx context.Context, tracer trace.Tracer, name, project, instance string) (*bigtable.Client, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
gocql "github.com/apache/cassandra-gocql-driver/v2"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
@@ -89,10 +90,32 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
||||
}
|
||||
|
||||
// SourceKind implements sources.Source.
|
||||
func (s Source) SourceKind() string {
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) RunSQL(ctx context.Context, statement string, params parameters.ParamValues) (any, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
iter := s.CassandraSession().Query(statement, sliceParams...).IterContext(ctx)
|
||||
|
||||
// Create a slice to store the out
|
||||
var out []map[string]interface{}
|
||||
|
||||
// Scan results into a map and append to the slice
|
||||
for {
|
||||
row := make(map[string]interface{}) // Create a new map for each row
|
||||
if !iter.MapScan(row) {
|
||||
break // No more rows
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
|
||||
if err := iter.Close(); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse rows: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
func initCassandraSession(ctx context.Context, tracer trace.Tracer, c Config) (*gocql.Session, error) {
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
_ "github.com/ClickHouse/clickhouse-go/v2"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
@@ -99,6 +100,69 @@ func (s *Source) ClickHousePool() *sql.DB {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func (s *Source) RunSQL(ctx context.Context, statement string, params parameters.ParamValues) (any, error) {
|
||||
var sliceParams []any
|
||||
if params != nil {
|
||||
sliceParams = params.AsSlice()
|
||||
}
|
||||
results, err := s.ClickHousePool().QueryContext(ctx, statement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
cols, err := results.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
||||
}
|
||||
|
||||
// create an array of values for each column, which can be re-used to scan each row
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
|
||||
colTypes, err := results.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
err := results.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
// ClickHouse driver may return specific types that need handling
|
||||
switch colTypes[i].DatabaseTypeName() {
|
||||
case "String", "FixedString":
|
||||
if rawValues[i] != nil {
|
||||
// Handle potential []byte to string conversion if needed
|
||||
if b, ok := rawValues[i].([]byte); ok {
|
||||
vMap[name] = string(b)
|
||||
} else {
|
||||
vMap[name] = rawValues[i]
|
||||
}
|
||||
} else {
|
||||
vMap[name] = nil
|
||||
}
|
||||
default:
|
||||
vMap[name] = rawValues[i]
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered by results.Scan: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func validateConfig(protocol string) error {
|
||||
validProtocols := map[string]bool{"http": true, "https": true}
|
||||
|
||||
|
||||
@@ -14,8 +14,11 @@
|
||||
package cloudgda
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
@@ -131,3 +134,43 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return s.UseClientOAuth
|
||||
}
|
||||
|
||||
func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) {
|
||||
// The API endpoint itself always uses the "global" location.
|
||||
apiLocation := "global"
|
||||
apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation)
|
||||
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent)
|
||||
|
||||
client, err := s.GetClient(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user