Compare commits

...

1 Commits

Author SHA1 Message Date
Yuan Teoh
2e1ac73c0c draft: add invoke interface to postgres source 2025-10-06 11:22:02 -07:00
2 changed files with 28 additions and 28 deletions

View File

@@ -23,6 +23,7 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/jackc/pgx/v5/pgxpool"
"go.opentelemetry.io/otel/trace"
)
@@ -92,8 +93,29 @@ func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) PostgresPool() *pgxpool.Pool {
return s.Pool
func (s *Source) RunSQL(ctx context.Context, statement string, params any) (any, error) {
sliceParams := params.(tools.ParamValues).AsSlice()
results, err := s.Pool.Query(ctx, statement, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
fields := results.FieldDescriptions()
var out []any
for results.Next() {
v, err := results.Values()
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, f := range fields {
vMap[f.Name] = v[i]
}
out = append(out, vMap)
}
return out, nil
}
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) {

View File

@@ -24,7 +24,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/jackc/pgx/v5/pgxpool"
)
const kind string = "postgres-sql"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
PostgresPool() *pgxpool.Pool
RunSQL(ctx context.Context, statement string, params any) (any, error)
}
// validate compatible sources are still compatible
@@ -101,7 +100,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
AllParams: allParameters,
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Pool: s.PostgresPool(),
source: s,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
@@ -119,7 +118,7 @@ type Tool struct {
TemplateParameters tools.Parameters `yaml:"templateParameters"`
AllParams tools.Parameters `yaml:"allParams"`
Pool *pgxpool.Pool
source compatibleSource
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
@@ -136,28 +135,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
if err != nil {
return nil, fmt.Errorf("unable to extract standard params %w", err)
}
sliceParams := newParams.AsSlice()
results, err := t.Pool.Query(ctx, newStatement, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
fields := results.FieldDescriptions()
var out []any
for results.Next() {
v, err := results.Values()
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, f := range fields {
vMap[f.Name] = v[i]
}
out = append(out, vMap)
}
return out, nil
return t.source.RunSQL(ctx, newStatement, newParams)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {