mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 16:38:15 -05:00
Compare commits
1 Commits
source-imp
...
invoke-int
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e1ac73c0c |
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
@@ -92,8 +93,29 @@ func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) PostgresPool() *pgxpool.Pool {
|
||||
return s.Pool
|
||||
func (s *Source) RunSQL(ctx context.Context, statement string, params any) (any, error) {
|
||||
sliceParams := params.(tools.ParamValues).AsSlice()
|
||||
results, err := s.Pool.Query(ctx, statement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
fields := results.FieldDescriptions()
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
v, err := results.Values()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, f := range fields {
|
||||
vMap[f.Name] = v[i]
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) {
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const kind string = "postgres-sql"
|
||||
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
PostgresPool() *pgxpool.Pool
|
||||
RunSQL(ctx context.Context, statement string, params any) (any, error)
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -101,7 +100,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
AllParams: allParameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.PostgresPool(),
|
||||
source: s,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -119,7 +118,7 @@ type Tool struct {
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *pgxpool.Pool
|
||||
source compatibleSource
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
@@ -136,28 +135,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
results, err := t.Pool.Query(ctx, newStatement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
fields := results.FieldDescriptions()
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
v, err := results.Values()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, f := range fields {
|
||||
vMap[f.Name] = v[i]
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return t.source.RunSQL(ctx, newStatement, newParams)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
|
||||
Reference in New Issue
Block a user