mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-04-09 03:02:26 -04:00
The `ParseParams` method on the Tool interface only passes the tool's parameters into the generic `parameters.ParseParams()` helper. Instead of keeping it on the Tool interface, we add a `GetParameters()` method (https://github.com/googleapis/genai-toolbox/pull/2374) to the interface and call the helper directly from the API handlers. This keeps the parameter-parsing logic independent of the tools.
210 lines
6.6 KiB
Go
210 lines
6.6 KiB
Go
// Copyright 2026 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package postgresliststoredprocedure
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
yaml "github.com/goccy/go-yaml"
|
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
|
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
|
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
|
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// resourceType is the tool type name under which this package registers itself.
const resourceType string = "postgres-list-stored-procedure"

// listStoredProcedure is the query executed by the tool. It returns one row per
// stored procedure (pg_proc entries with prokind = 'p') with its schema, name,
// owner, language, full definition, and description.
//
// Positional arguments:
//
//	$1  substring matched against the owner role name; NULL disables the filter
//	$2  substring matched against the schema name; NULL disables the filter
//	$3  maximum number of rows; NULL falls back to 20
//
// The filters use LIKE, so '%' and '_' inside an argument act as wildcards.
const listStoredProcedure = `
	SELECT
		n.nspname AS schema_name,
		p.proname AS name,
		r.rolname AS owner,
		l.lanname AS language,
		pg_catalog.pg_get_functiondef(p.oid) AS definition,
		pg_catalog.obj_description(p.oid, 'pg_proc') AS description
	FROM pg_catalog.pg_proc p
	JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
	JOIN pg_catalog.pg_roles r ON r.oid = p.proowner
	JOIN pg_catalog.pg_language l ON l.oid = p.prolang
	WHERE
		p.prokind = 'p' AND
		($1::text IS NULL OR r.rolname LIKE '%' || $1::text || '%') AND
		($2::text IS NULL OR n.nspname LIKE '%' || $2::text || '%')
	ORDER BY n.nspname, p.proname
	LIMIT
		COALESCE($3::int, 20);
`
|
|
|
|
func init() {
|
|
if !tools.Register(resourceType, newConfig) {
|
|
panic(fmt.Sprintf("tool type %q already registered", resourceType))
|
|
}
|
|
}
|
|
|
|
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
|
actual := Config{Name: name}
|
|
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
|
return nil, err
|
|
}
|
|
return actual, nil
|
|
}
|
|
|
|
type compatibleSource interface {
|
|
PostgresPool() *pgxpool.Pool
|
|
}
|
|
|
|
// validate compatible sources are still compatible
|
|
var _ compatibleSource = &alloydbpg.Source{}
|
|
var _ compatibleSource = &cloudsqlpg.Source{}
|
|
var _ compatibleSource = &postgres.Source{}
|
|
|
|
var compatibleSources = [...]string{alloydbpg.SourceType, cloudsqlpg.SourceType, postgres.SourceType}
|
|
|
|
type Config struct {
|
|
Name string `yaml:"name" validate:"required"`
|
|
Type string `yaml:"type" validate:"required"`
|
|
Source string `yaml:"source" validate:"required"`
|
|
Description string `yaml:"description"`
|
|
AuthRequired []string `yaml:"authRequired"`
|
|
}
|
|
|
|
// validate interface
|
|
var _ tools.ToolConfig = Config{}
|
|
|
|
func (cfg Config) ToolConfigType() string {
|
|
return resourceType
|
|
}
|
|
|
|
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
|
// verify source exists
|
|
rawS, ok := srcs[cfg.Source]
|
|
if !ok {
|
|
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
|
}
|
|
|
|
// verify the source is compatible
|
|
s, ok := rawS.(compatibleSource)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid source for %q tool: source type must be one of %q", resourceType, compatibleSources)
|
|
}
|
|
|
|
allParameters := parameters.Parameters{
|
|
parameters.NewStringParameterWithRequired("role_name", "Optional: The owner name to filter the stored procedures by. Defaults to NULL.", false),
|
|
parameters.NewStringParameterWithRequired("schema_name", "Optional: The schema name to filter the stored procedures by. Defaults to NULL.", false),
|
|
parameters.NewIntParameterWithDefault("limit", 20, "Optional: The maximum number of stored procedures to return. Defaults to 20."),
|
|
}
|
|
paramManifest := allParameters.Manifest()
|
|
|
|
if cfg.Description == "" {
|
|
cfg.Description = "Retrieves stored procedure metadata returning schema name, procedure name, procedure owner, language, definition, and description, filtered by optional role name (procedure owner), schema name, and limit (default 20)."
|
|
}
|
|
|
|
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
|
|
|
// finish tool setup
|
|
return Tool{
|
|
Config: cfg,
|
|
allParams: allParameters,
|
|
pool: s.PostgresPool(),
|
|
manifest: tools.Manifest{
|
|
Description: cfg.Description,
|
|
Parameters: paramManifest,
|
|
AuthRequired: cfg.AuthRequired,
|
|
},
|
|
mcpManifest: mcpManifest,
|
|
}, nil
|
|
}
|
|
|
|
// validate interface
|
|
var _ tools.Tool = Tool{}
|
|
|
|
type Tool struct {
|
|
Config
|
|
allParams parameters.Parameters `yaml:"allParams"`
|
|
pool *pgxpool.Pool
|
|
manifest tools.Manifest
|
|
mcpManifest tools.McpManifest
|
|
}
|
|
|
|
func (t Tool) ToConfig() tools.ToolConfig {
|
|
return t.Config
|
|
}
|
|
|
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
|
paramsMap := params.AsMap()
|
|
|
|
newParams, err := parameters.GetParams(t.allParams, paramsMap)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
|
}
|
|
sliceParams := newParams.AsSlice()
|
|
|
|
results, err := t.pool.Query(ctx, listStoredProcedure, sliceParams...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
}
|
|
defer results.Close()
|
|
|
|
fields := results.FieldDescriptions()
|
|
var out []map[string]any
|
|
|
|
for results.Next() {
|
|
values, err := results.Values()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to parse row: %w", err)
|
|
}
|
|
rowMap := make(map[string]any)
|
|
for i, field := range fields {
|
|
rowMap[string(field.Name)] = values[i]
|
|
}
|
|
out = append(out, rowMap)
|
|
}
|
|
|
|
return out, nil
|
|
}
|
|
|
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
|
return parameters.EmbedParams(ctx, t.allParams, paramValues, embeddingModelsMap, nil)
|
|
}
|
|
|
|
func (t Tool) Manifest() tools.Manifest {
|
|
return t.manifest
|
|
}
|
|
|
|
func (t Tool) McpManifest() tools.McpManifest {
|
|
return t.mcpManifest
|
|
}
|
|
|
|
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
|
}
|
|
|
|
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
|
return false, nil
|
|
}
|
|
|
|
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
|
return "Authorization", nil
|
|
}
|
|
|
|
func (t Tool) GetParameters() parameters.Parameters {
|
|
return t.allParams
|
|
}
|