Compare commits

...

1 Commits

Author SHA1 Message Date
Yuan Teoh
2e1ac73c0c draft: add invoke interface to postgres source 2025-10-06 11:22:02 -07:00
2 changed files with 28 additions and 28 deletions

View File

@@ -23,6 +23,7 @@ import (
"github.com/goccy/go-yaml" "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace"
) )
@@ -92,8 +93,29 @@ func (s *Source) SourceKind() string {
return SourceKind return SourceKind
} }
func (s *Source) PostgresPool() *pgxpool.Pool { func (s *Source) RunSQL(ctx context.Context, statement string, params any) (any, error) {
return s.Pool sliceParams := params.(tools.ParamValues).AsSlice()
results, err := s.Pool.Query(ctx, statement, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
fields := results.FieldDescriptions()
var out []any
for results.Next() {
v, err := results.Values()
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, f := range fields {
vMap[f.Name] = v[i]
}
out = append(out, vMap)
}
return out, nil
} }
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) { func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) {

View File

@@ -24,7 +24,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
"github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/sources/postgres"
"github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools"
"github.com/jackc/pgx/v5/pgxpool"
) )
const kind string = "postgres-sql" const kind string = "postgres-sql"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
} }
type compatibleSource interface { type compatibleSource interface {
PostgresPool() *pgxpool.Pool RunSQL(ctx context.Context, statement string, params any) (any, error)
} }
// validate compatible sources are still compatible // validate compatible sources are still compatible
@@ -101,7 +100,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
AllParams: allParameters, AllParams: allParameters,
Statement: cfg.Statement, Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired, AuthRequired: cfg.AuthRequired,
Pool: s.PostgresPool(), source: s,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest, mcpManifest: mcpManifest,
} }
@@ -119,7 +118,7 @@ type Tool struct {
TemplateParameters tools.Parameters `yaml:"templateParameters"` TemplateParameters tools.Parameters `yaml:"templateParameters"`
AllParams tools.Parameters `yaml:"allParams"` AllParams tools.Parameters `yaml:"allParams"`
Pool *pgxpool.Pool source compatibleSource
Statement string Statement string
manifest tools.Manifest manifest tools.Manifest
mcpManifest tools.McpManifest mcpManifest tools.McpManifest
@@ -136,28 +135,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to extract standard params %w", err) return nil, fmt.Errorf("unable to extract standard params %w", err)
} }
sliceParams := newParams.AsSlice() return t.source.RunSQL(ctx, newStatement, newParams)
results, err := t.Pool.Query(ctx, newStatement, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
fields := results.FieldDescriptions()
var out []any
for results.Next() {
v, err := results.Values()
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, f := range fields {
vMap[f.Name] = v[i]
}
out = append(out, vMap)
}
return out, nil
} }
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {