Files
genai-toolbox/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go
Yuan Teoh c1b87e209f refactor: move source implementation in Invoke() function to Source (#2240)
Move source-related queries from `Invoke()` function into Source.

The following sources are updated in this PR:
* alloydb-pg
* cloudsql-pg
* postgres

This is an effort to generalizing tools to work with any Source that
implements a specific interface. This will provide a better segregation
of the roles for Tools vs Source.

Tool's role will be limited to the following:
* Resolve any pre-implementation steps or parameters (e.g. template
parameters)
* Retrieving Source
* Calling the source's implementation
2025-12-31 04:46:18 +00:00

162 lines
4.6 KiB
Go

// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgreslistlocks
import (
"context"
"fmt"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"github.com/jackc/pgx/v5/pgxpool"
)
const kind string = "postgres-list-locks"
const listLocks = `
SELECT
locked.pid,
locked.usename,
locked.query,
string_agg(locked.transactionid::text,':') as trxid,
string_agg(locked.lockinfo,'||') as locks
FROM
(SELECT
a.pid,
a.usename,
a.query,
l.transactionid,
(l.granted::text||','||coalesce(l.relation::regclass,0)::text||','||l.mode::text)::text as lockinfo
FROM
pg_stat_activity a
JOIN pg_locks l ON l.pid = a.pid AND a.pid != pg_backend_pid()) as locked
GROUP BY
locked.pid, locked.usename, locked.query;
`
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type compatibleSource interface {
PostgresPool() *pgxpool.Pool
RunSQL(context.Context, string, []any) (any, error)
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description"`
AuthRequired []string `yaml:"authRequired"`
}
// validate interface
var _ tools.ToolConfig = Config{}
func (cfg Config) ToolConfigKind() string {
return kind
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
allParameters := parameters.Parameters{}
paramManifest := allParameters.Manifest()
if cfg.Description == "" {
cfg.Description = "Identifies all locks held by active processes showing the process ID, user, query text, and an aggregated list of all transactions and specific locks (relation, mode, grant status) associated with each process."
}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
// finish tool setup
return Tool{
Config: cfg,
allParams: allParameters,
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: paramManifest,
AuthRequired: cfg.AuthRequired,
},
mcpManifest: mcpManifest,
}, nil
}
// validate interface
var _ tools.Tool = Tool{}
type Tool struct {
Config
allParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
return t.Config
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
newParams, err := parameters.GetParams(t.allParams, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract standard params %w", err)
}
sliceParams := newParams.AsSlice()
return source.RunSQL(ctx, listLocks, sliceParams)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.allParams, data, claims)
}
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}