mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-14 09:57:58 -05:00
Compare commits
1 Commits
host-error
...
invoke-int
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e1ac73c0c |
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
@@ -92,8 +93,29 @@ func (s *Source) SourceKind() string {
|
|||||||
return SourceKind
|
return SourceKind
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Source) PostgresPool() *pgxpool.Pool {
|
func (s *Source) RunSQL(ctx context.Context, statement string, params any) (any, error) {
|
||||||
return s.Pool
|
sliceParams := params.(tools.ParamValues).AsSlice()
|
||||||
|
results, err := s.Pool.Query(ctx, statement, sliceParams...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := results.FieldDescriptions()
|
||||||
|
|
||||||
|
var out []any
|
||||||
|
for results.Next() {
|
||||||
|
v, err := results.Values()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||||
|
}
|
||||||
|
vMap := make(map[string]any)
|
||||||
|
for i, f := range fields {
|
||||||
|
vMap[f.Name] = v[i]
|
||||||
|
}
|
||||||
|
out = append(out, vMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) {
|
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) {
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const kind string = "postgres-sql"
|
const kind string = "postgres-sql"
|
||||||
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
}
|
}
|
||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
PostgresPool() *pgxpool.Pool
|
RunSQL(ctx context.Context, statement string, params any) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate compatible sources are still compatible
|
// validate compatible sources are still compatible
|
||||||
@@ -101,7 +100,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
AllParams: allParameters,
|
AllParams: allParameters,
|
||||||
Statement: cfg.Statement,
|
Statement: cfg.Statement,
|
||||||
AuthRequired: cfg.AuthRequired,
|
AuthRequired: cfg.AuthRequired,
|
||||||
Pool: s.PostgresPool(),
|
source: s,
|
||||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||||
mcpManifest: mcpManifest,
|
mcpManifest: mcpManifest,
|
||||||
}
|
}
|
||||||
@@ -119,7 +118,7 @@ type Tool struct {
|
|||||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||||
AllParams tools.Parameters `yaml:"allParams"`
|
AllParams tools.Parameters `yaml:"allParams"`
|
||||||
|
|
||||||
Pool *pgxpool.Pool
|
source compatibleSource
|
||||||
Statement string
|
Statement string
|
||||||
manifest tools.Manifest
|
manifest tools.Manifest
|
||||||
mcpManifest tools.McpManifest
|
mcpManifest tools.McpManifest
|
||||||
@@ -136,28 +135,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||||
}
|
}
|
||||||
sliceParams := newParams.AsSlice()
|
return t.source.RunSQL(ctx, newStatement, newParams)
|
||||||
results, err := t.Pool.Query(ctx, newStatement, sliceParams...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fields := results.FieldDescriptions()
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
for results.Next() {
|
|
||||||
v, err := results.Values()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
|
||||||
}
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for i, f := range fields {
|
|
||||||
vMap[f.Name] = v[i]
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user