mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 23:48:04 -05:00
refactor: move source implementation in Invoke() function to Source (#2236)
Move source-related queries from `Invoke()` function into Source. The following sources are updated in this PR: * mindsdb * oceanbase * oracle * redis * singlestore * cloudmonitoring This is an effort to generalizing tools to work with any Source that implements a specific interface. This will provide a better segregation of the roles for Tools vs Source. Tool's role will be limited to the following: * Resolve any pre-implementation steps or parameters (e.g. template parameters) * Retrieving Source * Calling the source's implementation * reduce oracle integration test coverage to 20%. There's no code change or test reduction in this PR. It might be because the Invoke() function was dedupe, hence the total line covered is reduced.
This commit is contained in:
@@ -875,8 +875,8 @@ steps:
|
|||||||
total_coverage=$(go tool cover -func=oracle_coverage.out | grep "total:" | awk '{print $3}')
|
total_coverage=$(go tool cover -func=oracle_coverage.out | grep "total:" | awk '{print $3}')
|
||||||
echo "Oracle total coverage: $total_coverage"
|
echo "Oracle total coverage: $total_coverage"
|
||||||
coverage_numeric=$(echo "$total_coverage" | sed 's/%//')
|
coverage_numeric=$(echo "$total_coverage" | sed 's/%//')
|
||||||
if awk -v cov="$coverage_numeric" 'BEGIN {exit !(cov < 30)}'; then
|
if awk -v cov="$coverage_numeric" 'BEGIN {exit !(cov < 20)}'; then
|
||||||
echo "Coverage failure: $total_coverage is below 30%."
|
echo "Coverage failure: $total_coverage is below 20%."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -37,7 +37,6 @@ require (
|
|||||||
github.com/google/go-cmp v0.7.0
|
github.com/google/go-cmp v0.7.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/jackc/pgx/v5 v5.7.6
|
github.com/jackc/pgx/v5 v5.7.6
|
||||||
github.com/json-iterator/go v1.1.12
|
|
||||||
github.com/looker-open-source/sdk-codegen/go v0.25.21
|
github.com/looker-open-source/sdk-codegen/go v0.25.21
|
||||||
github.com/microsoft/go-mssqldb v1.9.3
|
github.com/microsoft/go-mssqldb v1.9.3
|
||||||
github.com/nakagami/firebirdsql v0.9.15
|
github.com/nakagami/firebirdsql v0.9.15
|
||||||
@@ -138,6 +137,7 @@ require (
|
|||||||
github.com/jcmturner/goidentity/v6 v6.0.1 // indirect
|
github.com/jcmturner/goidentity/v6 v6.0.1 // indirect
|
||||||
github.com/jcmturner/gokrb5/v8 v8.4.4 // indirect
|
github.com/jcmturner/gokrb5/v8 v8.4.4 // indirect
|
||||||
github.com/jcmturner/rpc/v2 v2.0.3 // indirect
|
github.com/jcmturner/rpc/v2 v2.0.3 // indirect
|
||||||
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
|
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
|
||||||
github.com/klauspost/compress v1.18.0 // indirect
|
github.com/klauspost/compress v1.18.0 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.11 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.11 // indirect
|
||||||
|
|||||||
@@ -15,7 +15,9 @@ package cloudmonitoring
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
@@ -131,3 +133,44 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien
|
|||||||
func (s *Source) UseClientAuthorization() bool {
|
func (s *Source) UseClientAuthorization() bool {
|
||||||
return s.UseClientOAuth
|
return s.UseClientOAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Source) RunQuery(projectID, query string) (any, error) {
|
||||||
|
url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", s.BaseURL(), projectID)
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
q := req.URL.Query()
|
||||||
|
q.Add("query", query)
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
req.Header.Set("User-Agent", s.UserAgent())
|
||||||
|
|
||||||
|
resp, err := s.Client().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("request failed: %s, body: %s", resp.Status, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]any
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal json: %w, body: %s", err, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -101,6 +102,61 @@ func (s *Source) MySQLPool() *sql.DB {
|
|||||||
return s.Pool
|
return s.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
|
||||||
|
// MindsDB now supports MySQL prepared statements natively
|
||||||
|
results, err := s.MindsDBPool().QueryContext(ctx, statement, params...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cols, err := results.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create an array of values for each column, which can be re-used to scan each row
|
||||||
|
rawValues := make([]any, len(cols))
|
||||||
|
values := make([]any, len(cols))
|
||||||
|
for i := range rawValues {
|
||||||
|
values[i] = &rawValues[i]
|
||||||
|
}
|
||||||
|
defer results.Close()
|
||||||
|
|
||||||
|
colTypes, err := results.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out []any
|
||||||
|
for results.Next() {
|
||||||
|
err := results.Scan(values...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||||
|
}
|
||||||
|
vMap := make(map[string]any)
|
||||||
|
for i, name := range cols {
|
||||||
|
val := rawValues[i]
|
||||||
|
if val == nil {
|
||||||
|
vMap[name] = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// MindsDB uses mysql driver
|
||||||
|
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, vMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := results.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
func initMindsDBConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
func initMindsDBConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
||||||
//nolint:all // Reassigned ctx
|
//nolint:all // Reassigned ctx
|
||||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -97,6 +98,60 @@ func (s *Source) OceanBasePool() *sql.DB {
|
|||||||
return s.Pool
|
return s.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
|
||||||
|
results, err := s.OceanBasePool().QueryContext(ctx, statement, params...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cols, err := results.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create an array of values for each column, which can be re-used to scan each row
|
||||||
|
rawValues := make([]any, len(cols))
|
||||||
|
values := make([]any, len(cols))
|
||||||
|
for i := range rawValues {
|
||||||
|
values[i] = &rawValues[i]
|
||||||
|
}
|
||||||
|
defer results.Close()
|
||||||
|
|
||||||
|
colTypes, err := results.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out []any
|
||||||
|
for results.Next() {
|
||||||
|
err := results.Scan(values...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||||
|
}
|
||||||
|
vMap := make(map[string]any)
|
||||||
|
for i, name := range cols {
|
||||||
|
val := rawValues[i]
|
||||||
|
if val == nil {
|
||||||
|
vMap[name] = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// oceanbase uses mysql driver
|
||||||
|
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, vMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := results.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
func initOceanBaseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
func initOceanBaseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
||||||
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
_, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package oracle
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -135,6 +136,107 @@ func (s *Source) OracleDB() *sql.DB {
|
|||||||
return s.DB
|
return s.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
|
||||||
|
rows, err := s.OracleDB().QueryContext(ctx, statement, params...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
// If Columns() errors, it might be a DDL/DML without an OUTPUT clause.
|
||||||
|
// We proceed, and results.Err() will catch actual query execution errors.
|
||||||
|
// 'out' will remain nil if cols is empty or err is not nil here.
|
||||||
|
cols, _ := rows.Columns()
|
||||||
|
|
||||||
|
// Get Column types
|
||||||
|
colTypes, err := rows.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("query execution error: %w", err)
|
||||||
|
}
|
||||||
|
return []any{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var out []any
|
||||||
|
for rows.Next() {
|
||||||
|
values := make([]any, len(cols))
|
||||||
|
for i, colType := range colTypes {
|
||||||
|
switch strings.ToUpper(colType.DatabaseTypeName()) {
|
||||||
|
case "NUMBER", "FLOAT", "BINARY_FLOAT", "BINARY_DOUBLE":
|
||||||
|
if _, scale, ok := colType.DecimalSize(); ok && scale == 0 {
|
||||||
|
// Scale is 0, treat it as an integer.
|
||||||
|
values[i] = new(sql.NullInt64)
|
||||||
|
} else {
|
||||||
|
// Scale is non-zero or unknown, treat
|
||||||
|
// it as a float.
|
||||||
|
values[i] = new(sql.NullFloat64)
|
||||||
|
}
|
||||||
|
case "DATE", "TIMESTAMP", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH LOCAL TIME ZONE":
|
||||||
|
values[i] = new(sql.NullTime)
|
||||||
|
case "JSON":
|
||||||
|
values[i] = new(sql.RawBytes)
|
||||||
|
default:
|
||||||
|
values[i] = new(sql.NullString)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Scan(values...); err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to scan row: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
vMap := make(map[string]any)
|
||||||
|
for i, col := range cols {
|
||||||
|
receiver := values[i]
|
||||||
|
|
||||||
|
switch v := receiver.(type) {
|
||||||
|
case *sql.NullInt64:
|
||||||
|
if v.Valid {
|
||||||
|
vMap[col] = v.Int64
|
||||||
|
} else {
|
||||||
|
vMap[col] = nil
|
||||||
|
}
|
||||||
|
case *sql.NullFloat64:
|
||||||
|
if v.Valid {
|
||||||
|
vMap[col] = v.Float64
|
||||||
|
} else {
|
||||||
|
vMap[col] = nil
|
||||||
|
}
|
||||||
|
case *sql.NullString:
|
||||||
|
if v.Valid {
|
||||||
|
vMap[col] = v.String
|
||||||
|
} else {
|
||||||
|
vMap[col] = nil
|
||||||
|
}
|
||||||
|
case *sql.NullTime:
|
||||||
|
if v.Valid {
|
||||||
|
vMap[col] = v.Time
|
||||||
|
} else {
|
||||||
|
vMap[col] = nil
|
||||||
|
}
|
||||||
|
case *sql.RawBytes:
|
||||||
|
if *v != nil {
|
||||||
|
var unmarshaledData any
|
||||||
|
if err := json.Unmarshal(*v, &unmarshaledData); err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to unmarshal json data for column %s", col)
|
||||||
|
}
|
||||||
|
vMap[col] = unmarshaledData
|
||||||
|
} else {
|
||||||
|
vMap[col] = nil
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected receiver type: %T", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, vMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Config) (*sql.DB, error) {
|
func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Config) (*sql.DB, error) {
|
||||||
//nolint:all // Reassigned ctx
|
//nolint:all // Reassigned ctx
|
||||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, config.Name)
|
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, config.Name)
|
||||||
|
|||||||
@@ -152,3 +152,50 @@ func (s *Source) ToConfig() sources.SourceConfig {
|
|||||||
func (s *Source) RedisClient() RedisClient {
|
func (s *Source) RedisClient() RedisClient {
|
||||||
return s.Client
|
return s.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Source) RunCommand(ctx context.Context, cmds [][]any) (any, error) {
|
||||||
|
// Execute commands
|
||||||
|
responses := make([]*redis.Cmd, len(cmds))
|
||||||
|
for i, cmd := range cmds {
|
||||||
|
responses[i] = s.RedisClient().Do(ctx, cmd...)
|
||||||
|
}
|
||||||
|
// Parse responses
|
||||||
|
out := make([]any, len(cmds))
|
||||||
|
for i, resp := range responses {
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
// Add error from each command to `errSum`
|
||||||
|
errString := fmt.Sprintf("error from executing command at index %d: %s", i, err)
|
||||||
|
out[i] = errString
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
val, err := resp.Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting result: %s", err)
|
||||||
|
}
|
||||||
|
out[i] = convertRedisResult(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertRedisResult recursively converts redis results (map[any]any) to be
|
||||||
|
// JSON-marshallable (map[string]any).
|
||||||
|
// It converts map[any]any to map[string]any and handles nested structures.
|
||||||
|
func convertRedisResult(v any) any {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case map[any]any:
|
||||||
|
m := make(map[string]any)
|
||||||
|
for k, v := range val {
|
||||||
|
m[fmt.Sprint(k)] = convertRedisResult(v)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
case []any:
|
||||||
|
s := make([]any, len(val))
|
||||||
|
for i, v := range val {
|
||||||
|
s[i] = convertRedisResult(v)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
default:
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -106,6 +107,59 @@ func (s *Source) SingleStorePool() *sql.DB {
|
|||||||
return s.Pool
|
return s.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
|
||||||
|
results, err := s.SingleStorePool().QueryContext(ctx, statement, params...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cols, err := results.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create an array of values for each column, which can be re-used to scan each row
|
||||||
|
rawValues := make([]any, len(cols))
|
||||||
|
values := make([]any, len(cols))
|
||||||
|
for i := range rawValues {
|
||||||
|
values[i] = &rawValues[i]
|
||||||
|
}
|
||||||
|
defer results.Close()
|
||||||
|
|
||||||
|
colTypes, err := results.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out []any
|
||||||
|
for results.Next() {
|
||||||
|
err := results.Scan(values...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||||
|
}
|
||||||
|
vMap := make(map[string]any)
|
||||||
|
for i, name := range cols {
|
||||||
|
val := rawValues[i]
|
||||||
|
if val == nil {
|
||||||
|
vMap[name] = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, vMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := results.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
func initSingleStoreConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
func initSingleStoreConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
||||||
//nolint:all // Reassigned ctx
|
//nolint:all // Reassigned ctx
|
||||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||||
|
|||||||
@@ -16,9 +16,7 @@ package cloudmonitoring
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
@@ -44,9 +42,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
}
|
}
|
||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
BaseURL() string
|
|
||||||
Client() *http.Client
|
Client() *http.Client
|
||||||
UserAgent() string
|
RunQuery(projectID, query string) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -110,45 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("query parameter not found or not a string")
|
return nil, fmt.Errorf("query parameter not found or not a string")
|
||||||
}
|
}
|
||||||
|
return source.RunQuery(projectID, query)
|
||||||
url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", source.BaseURL(), projectID)
|
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
q := req.URL.Query()
|
|
||||||
q.Add("query", query)
|
|
||||||
req.URL.RawQuery = q.Encode()
|
|
||||||
|
|
||||||
req.Header.Set("User-Agent", source.UserAgent())
|
|
||||||
|
|
||||||
resp, err := source.Client().Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("request failed: %s, body: %s", resp.Status, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(body) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var result map[string]any
|
|
||||||
if err := json.Unmarshal(body, &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to unmarshal json: %w, body: %s", err, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import (
|
|||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
MindsDBPool() *sql.DB
|
MindsDBPool() *sql.DB
|
||||||
|
RunSQL(context.Context, string, []any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -108,57 +108,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"])
|
return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"])
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := source.MindsDBPool().QueryContext(ctx, sql)
|
return source.RunSQL(ctx, sql, nil)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
defer results.Close()
|
|
||||||
|
|
||||||
cols, err := results.Columns()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// create an array of values for each column, which can be re-used to scan each row
|
|
||||||
rawValues := make([]any, len(cols))
|
|
||||||
values := make([]any, len(cols))
|
|
||||||
for i := range rawValues {
|
|
||||||
values[i] = &rawValues[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
colTypes, err := results.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
for results.Next() {
|
|
||||||
err := results.Scan(values...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
|
||||||
}
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for i, name := range cols {
|
|
||||||
val := rawValues[i]
|
|
||||||
if val == nil {
|
|
||||||
vMap[name] = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// MindsDB uses mysql driver
|
|
||||||
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := results.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import (
|
|||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
MindsDBPool() *sql.DB
|
MindsDBPool() *sql.DB
|
||||||
|
RunSQL(context.Context, string, []any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -116,59 +116,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
}
|
}
|
||||||
|
|
||||||
sliceParams := newParams.AsSlice()
|
sliceParams := newParams.AsSlice()
|
||||||
|
return source.RunSQL(ctx, newStatement, sliceParams)
|
||||||
// MindsDB now supports MySQL prepared statements natively
|
|
||||||
results, err := source.MindsDBPool().QueryContext(ctx, newStatement, sliceParams...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cols, err := results.Columns()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// create an array of values for each column, which can be re-used to scan each row
|
|
||||||
rawValues := make([]any, len(cols))
|
|
||||||
values := make([]any, len(cols))
|
|
||||||
for i := range rawValues {
|
|
||||||
values[i] = &rawValues[i]
|
|
||||||
}
|
|
||||||
defer results.Close()
|
|
||||||
|
|
||||||
colTypes, err := results.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
for results.Next() {
|
|
||||||
err := results.Scan(values...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
|
||||||
}
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for i, name := range cols {
|
|
||||||
val := rawValues[i]
|
|
||||||
if val == nil {
|
|
||||||
vMap[name] = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// MindsDB uses mysql driver
|
|
||||||
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := results.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import (
|
|||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -36,6 +35,7 @@ func init() {
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
OceanBasePool() *sql.DB
|
OceanBasePool() *sql.DB
|
||||||
|
RunSQL(context.Context, string, []any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -99,58 +99,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unable to get cast %s", sliceParams[0])
|
return nil, fmt.Errorf("unable to get cast %s", sliceParams[0])
|
||||||
}
|
}
|
||||||
|
return source.RunSQL(ctx, sqlStr, nil)
|
||||||
results, err := source.OceanBasePool().QueryContext(ctx, sqlStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
defer results.Close()
|
|
||||||
|
|
||||||
cols, err := results.Columns()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// create an array of values for each column, which can be re-used to scan each row
|
|
||||||
rawValues := make([]any, len(cols))
|
|
||||||
values := make([]any, len(cols))
|
|
||||||
for i := range rawValues {
|
|
||||||
values[i] = &rawValues[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
colTypes, err := results.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
for results.Next() {
|
|
||||||
err := results.Scan(values...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
|
||||||
}
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for i, name := range cols {
|
|
||||||
val := rawValues[i]
|
|
||||||
if val == nil {
|
|
||||||
vMap[name] = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// oceanbase uses mysql driver
|
|
||||||
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := results.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseParams parses the input parameters for the tool.
|
// ParseParams parses the input parameters for the tool.
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import (
|
|||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -36,6 +35,7 @@ func init() {
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
OceanBasePool() *sql.DB
|
OceanBasePool() *sql.DB
|
||||||
|
RunSQL(context.Context, string, []any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -109,59 +109,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sliceParams := newParams.AsSlice()
|
sliceParams := newParams.AsSlice()
|
||||||
results, err := source.OceanBasePool().QueryContext(ctx, newStatement, sliceParams...)
|
return source.RunSQL(ctx, newStatement, sliceParams)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cols, err := results.Columns()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// create an array of values for each column, which can be re-used to scan each row
|
|
||||||
rawValues := make([]any, len(cols))
|
|
||||||
values := make([]any, len(cols))
|
|
||||||
for i := range rawValues {
|
|
||||||
values[i] = &rawValues[i]
|
|
||||||
}
|
|
||||||
defer results.Close()
|
|
||||||
|
|
||||||
colTypes, err := results.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
for results.Next() {
|
|
||||||
err := results.Scan(values...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
|
||||||
}
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for i, name := range cols {
|
|
||||||
val := rawValues[i]
|
|
||||||
if val == nil {
|
|
||||||
vMap[name] = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// oceanbase uses mysql driver
|
|
||||||
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := results.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseParams parses the input parameters for the tool.
|
// ParseParams parses the input parameters for the tool.
|
||||||
|
|||||||
@@ -5,9 +5,7 @@ package oracleexecutesql
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
@@ -34,6 +32,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
OracleDB() *sql.DB
|
OracleDB() *sql.DB
|
||||||
|
RunSQL(context.Context, string, []any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -95,107 +94,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sqlParam)
|
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sqlParam)
|
||||||
|
return source.RunSQL(ctx, sqlParam, nil)
|
||||||
results, err := source.OracleDB().QueryContext(ctx, sqlParam)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
defer results.Close()
|
|
||||||
|
|
||||||
// If Columns() errors, it might be a DDL/DML without an OUTPUT clause.
|
|
||||||
// We proceed, and results.Err() will catch actual query execution errors.
|
|
||||||
// 'out' will remain nil if cols is empty or err is not nil here.
|
|
||||||
cols, _ := results.Columns()
|
|
||||||
|
|
||||||
// Get Column types
|
|
||||||
colTypes, err := results.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
if err := results.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("query execution error: %w", err)
|
|
||||||
}
|
|
||||||
return []any{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
for results.Next() {
|
|
||||||
// Create slice to hold values
|
|
||||||
values := make([]any, len(cols))
|
|
||||||
for i, colType := range colTypes {
|
|
||||||
// Based on the database type, we prepare a pointer to a Go type.
|
|
||||||
switch strings.ToUpper(colType.DatabaseTypeName()) {
|
|
||||||
case "NUMBER", "FLOAT", "BINARY_FLOAT", "BINARY_DOUBLE":
|
|
||||||
if _, scale, ok := colType.DecimalSize(); ok && scale == 0 {
|
|
||||||
// Scale is 0, treat as an integer.
|
|
||||||
values[i] = new(sql.NullInt64)
|
|
||||||
} else {
|
|
||||||
// Scale is non-zero or unknown, treat as a float.
|
|
||||||
values[i] = new(sql.NullFloat64)
|
|
||||||
}
|
|
||||||
case "DATE", "TIMESTAMP", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH LOCAL TIME ZONE":
|
|
||||||
values[i] = new(sql.NullTime)
|
|
||||||
case "JSON":
|
|
||||||
values[i] = new(sql.RawBytes)
|
|
||||||
default:
|
|
||||||
values[i] = new(sql.NullString)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := results.Scan(values...); err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to scan row: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for i, col := range cols {
|
|
||||||
receiver := values[i]
|
|
||||||
|
|
||||||
// Dereference the pointer and check for validity (not NULL).
|
|
||||||
switch v := receiver.(type) {
|
|
||||||
case *sql.NullInt64:
|
|
||||||
if v.Valid {
|
|
||||||
vMap[col] = v.Int64
|
|
||||||
} else {
|
|
||||||
vMap[col] = nil
|
|
||||||
}
|
|
||||||
case *sql.NullFloat64:
|
|
||||||
if v.Valid {
|
|
||||||
vMap[col] = v.Float64
|
|
||||||
} else {
|
|
||||||
vMap[col] = nil
|
|
||||||
}
|
|
||||||
case *sql.NullString:
|
|
||||||
if v.Valid {
|
|
||||||
vMap[col] = v.String
|
|
||||||
} else {
|
|
||||||
vMap[col] = nil
|
|
||||||
}
|
|
||||||
case *sql.NullTime:
|
|
||||||
if v.Valid {
|
|
||||||
vMap[col] = v.Time
|
|
||||||
} else {
|
|
||||||
vMap[col] = nil
|
|
||||||
}
|
|
||||||
case *sql.RawBytes:
|
|
||||||
if *v != nil {
|
|
||||||
var unmarshaledData any
|
|
||||||
if err := json.Unmarshal(*v, &unmarshaledData); err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to unmarshal json data for column %s", col)
|
|
||||||
}
|
|
||||||
vMap[col] = unmarshaledData
|
|
||||||
} else {
|
|
||||||
vMap[col] = nil
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unexpected receiver type: %T", v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := results.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -5,9 +5,7 @@ package oraclesql
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
@@ -33,6 +31,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
OracleDB() *sql.DB
|
OracleDB() *sql.DB
|
||||||
|
RunSQL(context.Context, string, []any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -103,99 +102,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
fmt.Printf("[%d]=%T ", i, p)
|
fmt.Printf("[%d]=%T ", i, p)
|
||||||
}
|
}
|
||||||
fmt.Printf("\n")
|
fmt.Printf("\n")
|
||||||
|
return source.RunSQL(ctx, newStatement, sliceParams)
|
||||||
rows, err := source.OracleDB().QueryContext(ctx, newStatement, sliceParams...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
cols, _ := rows.Columns()
|
|
||||||
|
|
||||||
// Get Column types
|
|
||||||
colTypes, err := rows.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
for rows.Next() {
|
|
||||||
values := make([]any, len(cols))
|
|
||||||
for i, colType := range colTypes {
|
|
||||||
switch strings.ToUpper(colType.DatabaseTypeName()) {
|
|
||||||
case "NUMBER", "FLOAT", "BINARY_FLOAT", "BINARY_DOUBLE":
|
|
||||||
if _, scale, ok := colType.DecimalSize(); ok && scale == 0 {
|
|
||||||
// Scale is 0, treat it as an integer.
|
|
||||||
values[i] = new(sql.NullInt64)
|
|
||||||
} else {
|
|
||||||
// Scale is non-zero or unknown, treat
|
|
||||||
// it as a float.
|
|
||||||
values[i] = new(sql.NullFloat64)
|
|
||||||
}
|
|
||||||
case "DATE", "TIMESTAMP", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH LOCAL TIME ZONE":
|
|
||||||
values[i] = new(sql.NullTime)
|
|
||||||
case "JSON":
|
|
||||||
values[i] = new(sql.RawBytes)
|
|
||||||
default:
|
|
||||||
values[i] = new(sql.NullString)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Scan(values...); err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to scan row: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for i, col := range cols {
|
|
||||||
receiver := values[i]
|
|
||||||
|
|
||||||
switch v := receiver.(type) {
|
|
||||||
case *sql.NullInt64:
|
|
||||||
if v.Valid {
|
|
||||||
vMap[col] = v.Int64
|
|
||||||
} else {
|
|
||||||
vMap[col] = nil
|
|
||||||
}
|
|
||||||
case *sql.NullFloat64:
|
|
||||||
if v.Valid {
|
|
||||||
vMap[col] = v.Float64
|
|
||||||
} else {
|
|
||||||
vMap[col] = nil
|
|
||||||
}
|
|
||||||
case *sql.NullString:
|
|
||||||
if v.Valid {
|
|
||||||
vMap[col] = v.String
|
|
||||||
} else {
|
|
||||||
vMap[col] = nil
|
|
||||||
}
|
|
||||||
case *sql.NullTime:
|
|
||||||
if v.Valid {
|
|
||||||
vMap[col] = v.Time
|
|
||||||
} else {
|
|
||||||
vMap[col] = nil
|
|
||||||
}
|
|
||||||
case *sql.RawBytes:
|
|
||||||
if *v != nil {
|
|
||||||
var unmarshaledData any
|
|
||||||
if err := json.Unmarshal(*v, &unmarshaledData); err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to unmarshal json data for column %s", col)
|
|
||||||
}
|
|
||||||
vMap[col] = unmarshaledData
|
|
||||||
} else {
|
|
||||||
vMap[col] = nil
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unexpected receiver type: %T", v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -22,8 +22,6 @@ import (
|
|||||||
redissrc "github.com/googleapis/genai-toolbox/internal/sources/redis"
|
redissrc "github.com/googleapis/genai-toolbox/internal/sources/redis"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
jsoniter "github.com/json-iterator/go"
|
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const kind string = "redis"
|
const kind string = "redis"
|
||||||
@@ -44,6 +42,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
RedisClient() redissrc.RedisClient
|
RedisClient() redissrc.RedisClient
|
||||||
|
RunCommand(context.Context, [][]any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -94,44 +93,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error replacing commands' parameters: %s", err)
|
return nil, fmt.Errorf("error replacing commands' parameters: %s", err)
|
||||||
}
|
}
|
||||||
|
return source.RunCommand(ctx, cmds)
|
||||||
// Execute commands
|
|
||||||
responses := make([]*redis.Cmd, len(cmds))
|
|
||||||
for i, cmd := range cmds {
|
|
||||||
responses[i] = source.RedisClient().Do(ctx, cmd...)
|
|
||||||
}
|
|
||||||
// Parse responses
|
|
||||||
out := make([]any, len(t.Commands))
|
|
||||||
for i, resp := range responses {
|
|
||||||
if err := resp.Err(); err != nil {
|
|
||||||
// Add error from each command to `errSum`
|
|
||||||
errString := fmt.Sprintf("error from executing command at index %d: %s", i, err)
|
|
||||||
out[i] = errString
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
val, err := resp.Result()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error getting result: %s", err)
|
|
||||||
}
|
|
||||||
// If result is a map, convert map[any]any to map[string]any
|
|
||||||
// Because the Go's built-in json/encoding marshalling doesn't support
|
|
||||||
// map[any]any as an input
|
|
||||||
var strMap map[string]any
|
|
||||||
var json = jsoniter.ConfigCompatibleWithStandardLibrary
|
|
||||||
mapStr, err := json.Marshal(val)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error marshalling result: %s", err)
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(mapStr, &strMap)
|
|
||||||
if err != nil {
|
|
||||||
// result is not a map
|
|
||||||
out[i] = val
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
out[i] = strMap
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import (
|
|||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
@@ -45,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
SingleStorePool() *sql.DB
|
SingleStorePool() *sql.DB
|
||||||
|
RunSQL(context.Context, string, []any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config represents the configuration for the singlestore-execute-sql tool.
|
// Config represents the configuration for the singlestore-execute-sql tool.
|
||||||
@@ -115,57 +115,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql)
|
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql)
|
||||||
|
return source.RunSQL(ctx, sql, nil)
|
||||||
results, err := source.SingleStorePool().QueryContext(ctx, sql)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
defer results.Close()
|
|
||||||
|
|
||||||
cols, err := results.Columns()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// create an array of values for each column, which can be re-used to scan each row
|
|
||||||
rawValues := make([]any, len(cols))
|
|
||||||
values := make([]any, len(cols))
|
|
||||||
for i := range rawValues {
|
|
||||||
values[i] = &rawValues[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
colTypes, err := results.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
for results.Next() {
|
|
||||||
err := results.Scan(values...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
|
||||||
}
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for i, name := range cols {
|
|
||||||
val := rawValues[i]
|
|
||||||
if val == nil {
|
|
||||||
vMap[name] = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := results.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import (
|
|||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
SingleStorePool() *sql.DB
|
SingleStorePool() *sql.DB
|
||||||
|
RunSQL(context.Context, string, []any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config defines the configuration for a SingleStore SQL tool.
|
// Config defines the configuration for a SingleStore SQL tool.
|
||||||
@@ -143,56 +143,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
}
|
}
|
||||||
|
|
||||||
sliceParams := newParams.AsSlice()
|
sliceParams := newParams.AsSlice()
|
||||||
results, err := source.SingleStorePool().QueryContext(ctx, newStatement, sliceParams...)
|
return source.RunSQL(ctx, newStatement, sliceParams)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cols, err := results.Columns()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// create an array of values for each column, which can be re-used to scan each row
|
|
||||||
rawValues := make([]any, len(cols))
|
|
||||||
values := make([]any, len(cols))
|
|
||||||
for i := range rawValues {
|
|
||||||
values[i] = &rawValues[i]
|
|
||||||
}
|
|
||||||
defer results.Close()
|
|
||||||
|
|
||||||
colTypes, err := results.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
for results.Next() {
|
|
||||||
err := results.Scan(values...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
|
||||||
}
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for i, name := range cols {
|
|
||||||
val := rawValues[i]
|
|
||||||
if val == nil {
|
|
||||||
vMap[name] = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := results.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user