mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-04-09 03:02:26 -04:00
422 lines
13 KiB
Go
422 lines
13 KiB
Go
// Copyright 2026 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package cockroachdb
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"math"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
crdbpgx "github.com/cockroachdb/cockroach-go/v2/crdb/crdbpgxv5"
|
|
"github.com/goccy/go-yaml"
|
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
|
"github.com/googleapis/genai-toolbox/internal/util"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"go.opentelemetry.io/otel/trace"
|
|
)
|
|
|
|
const SourceType string = "cockroachdb"
|
|
|
|
var _ sources.SourceConfig = Config{}
|
|
|
|
func init() {
|
|
if !sources.Register(SourceType, newConfig) {
|
|
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
|
}
|
|
}
|
|
|
|
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
|
// MCP compliance: Read-only by default, require explicit opt-in for writes
|
|
actual := Config{
|
|
Name: name,
|
|
MaxRetries: 5,
|
|
RetryBaseDelay: "500ms",
|
|
ReadOnlyMode: true, // MCP requirement: read-only by default
|
|
EnableWriteMode: false, // Must be explicitly enabled
|
|
MaxRowLimit: 1000, // MCP requirement: limit query results
|
|
QueryTimeoutSec: 30, // MCP requirement: prevent long-running queries
|
|
EnableTelemetry: true, // MCP requirement: observability
|
|
TelemetryVerbose: false,
|
|
}
|
|
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Security validation: If EnableWriteMode is true, ReadOnlyMode should be false
|
|
if actual.EnableWriteMode {
|
|
actual.ReadOnlyMode = false
|
|
}
|
|
|
|
return actual, nil
|
|
}
|
|
|
|
// Config is the YAML configuration of a CockroachDB source. newConfig assigns
// defaults for the retry, MCP security and observability fields before
// decoding, so any value present in the YAML overrides them.
type Config struct {
	Name        string            `yaml:"name" validate:"required"`
	Type        string            `yaml:"type" validate:"required"`
	Host        string            `yaml:"host" validate:"required"`
	Port        string            `yaml:"port" validate:"required"`
	User        string            `yaml:"user" validate:"required"`
	Password    string            `yaml:"password"`
	Database    string            `yaml:"database" validate:"required"`
	QueryParams map[string]string `yaml:"queryParams"` // Extra connection-URL query parameters; application_name defaults to the user agent
	// MaxRetries is the number of extra connection attempts made after the
	// first one fails. Default: 5.
	MaxRetries int `yaml:"maxRetries"`
	// RetryBaseDelay is parsed with time.ParseDuration; the wait after failed
	// connection attempt n (0-based) is RetryBaseDelay * 2^n. Default: "500ms".
	RetryBaseDelay string `yaml:"retryBaseDelay"`

	// MCP Security Features
	ReadOnlyMode    bool `yaml:"readOnlyMode"`    // Default: true (set in newConfig); cleared by newConfig when EnableWriteMode is true
	EnableWriteMode bool `yaml:"enableWriteMode"` // Explicit opt-in for write operations; takes precedence over ReadOnlyMode
	MaxRowLimit     int  `yaml:"maxRowLimit"`     // Default: 1000; used by ApplyQueryLimits for SELECTs, values <= 0 disable it
	QueryTimeoutSec int  `yaml:"queryTimeoutSec"` // Default: 30. NOTE(review): not read anywhere in this file — confirm where it is enforced

	// Observability
	EnableTelemetry  bool   `yaml:"enableTelemetry"`  // Default: true; EmitTelemetry does nothing when false
	TelemetryVerbose bool   `yaml:"telemetryVerbose"` // Default: false; when true EmitTelemetry logs the whole event as JSON
	ClusterID        string `yaml:"clusterID"`        // Optional cluster identifier for telemetry; Database is used when empty
}
|
|
|
|
func (r Config) SourceConfigType() string {
|
|
return SourceType
|
|
}
|
|
|
|
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
|
retryBaseDelay, err := time.ParseDuration(r.RetryBaseDelay)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid retryBaseDelay: %w", err)
|
|
}
|
|
|
|
pool, err := initCockroachDBConnectionPoolWithRetry(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryParams, r.MaxRetries, retryBaseDelay)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to create pool: %w", err)
|
|
}
|
|
|
|
s := &Source{
|
|
Config: r,
|
|
Pool: pool,
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
var _ sources.Source = &Source{}
|
|
|
|
type Source struct {
|
|
Config
|
|
Pool *pgxpool.Pool
|
|
}
|
|
|
|
func (s *Source) SourceType() string {
|
|
return SourceType
|
|
}
|
|
|
|
func (s *Source) ToConfig() sources.SourceConfig {
|
|
return s.Config
|
|
}
|
|
|
|
func (s *Source) CockroachDBPool() *pgxpool.Pool {
|
|
return s.Pool
|
|
}
|
|
|
|
func (s *Source) PostgresPool() *pgxpool.Pool {
|
|
return s.Pool
|
|
}
|
|
|
|
// ExecuteTxWithRetry executes a function within a transaction with automatic retry logic
|
|
// using the official CockroachDB retry mechanism from cockroach-go/v2
|
|
func (s *Source) ExecuteTxWithRetry(ctx context.Context, fn func(pgx.Tx) error) error {
|
|
return crdbpgx.ExecuteTx(ctx, s.Pool, pgx.TxOptions{}, fn)
|
|
}
|
|
|
|
// Query executes a query using the connection pool with MCP security enforcement.
|
|
// For read-only queries, connection-level retry is sufficient.
|
|
// For write operations requiring transaction retry, use ExecuteTxWithRetry directly.
|
|
// Note: Callers should manage context timeouts as needed.
|
|
func (s *Source) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
|
|
// MCP Security Check 1: Enforce write operation restrictions
|
|
if err := s.CanExecuteWrite(sql); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// MCP Security Check 2: Apply query limits (row limit)
|
|
modifiedSQL, err := s.ApplyQueryLimits(sql)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s.Pool.Query(ctx, modifiedSQL, args...)
|
|
}
|
|
|
|
// ============================================================================
|
|
// MCP Security & Observability Features
|
|
// ============================================================================
|
|
|
|
// TelemetryEvent represents a structured telemetry event for one MCP tool
// call; it is logged by EmitTelemetry. The json tags name the fields in the
// JSON form that EmitTelemetry logs when TelemetryVerbose is set.
type TelemetryEvent struct {
	Timestamp    time.Time         `json:"timestamp"`
	ToolName     string            `json:"tool_name"`
	ClusterID    string            `json:"cluster_id"`   // When empty, EmitTelemetry uses Config.ClusterID, then the database name
	Database     string            `json:"database"`     // When empty, EmitTelemetry uses the source's database
	User         string            `json:"user"`         // When empty, EmitTelemetry uses the source's user
	SQLRedacted  string            `json:"sql_redacted"` // Query with values redacted (e.g. by RedactSQL); EmitTelemetry logs it as given
	Status       string            `json:"status"`       // "success" | "failure"
	ErrorCode    string            `json:"error_code,omitempty"` // Presumably one of the ErrCode* constants; not validated here
	ErrorMsg     string            `json:"error_msg,omitempty"`
	LatencyMs    int64             `json:"latency_ms"`
	RowsAffected int64             `json:"rows_affected,omitempty"`
	Metadata     map[string]string `json:"metadata,omitempty"`
}
|
|
|
|
// StructuredError is an MCP-compliant error. It carries a machine-readable
// code (typically one of the ErrCode* constants), a human-readable message and
// optional structured details.
type StructuredError struct {
	Code    string         `json:"error_code"`
	Message string         `json:"message"`
	Details map[string]any `json:"details,omitempty"`
}

// Error formats the error as "[<Code>] <Message>".
func (e *StructuredError) Error() string {
	return "[" + e.Code + "] " + e.Message
}
|
|
|
|
// MCP error codes, used as StructuredError.Code. Within this file only
// ErrCodeReadOnlyViolation is produced (by CanExecuteWrite). The remaining
// codes are not referenced here; presumably callers elsewhere use them.
const (
	ErrCodeUnauthorized         = "CRDB_UNAUTHORIZED"
	ErrCodeReadOnlyViolation    = "CRDB_READONLY_VIOLATION"
	ErrCodeQueryTimeout         = "CRDB_QUERY_TIMEOUT"
	ErrCodeRowLimitExceeded     = "CRDB_ROW_LIMIT_EXCEEDED"
	ErrCodeInvalidSQL           = "CRDB_INVALID_SQL"
	ErrCodeConnectionFailed     = "CRDB_CONNECTION_FAILED"
	ErrCodeWriteModeRequired    = "CRDB_WRITE_MODE_REQUIRED"
	ErrCodeQueryExecutionFailed = "CRDB_QUERY_EXECUTION_FAILED"
)
|
|
|
|
// SQLStatementType is the kind of SQL statement reported by ClassifySQL.
type SQLStatementType int

const (
	// SQLTypeUnknown covers empty input and any leading keyword not listed
	// below (for example WITH, BEGIN or GRANT).
	SQLTypeUnknown SQLStatementType = iota
	SQLTypeSelect
	SQLTypeInsert // INSERT, and CockroachDB's UPSERT
	SQLTypeUpdate
	SQLTypeDelete
	SQLTypeDDL // CREATE, ALTER, DROP
	SQLTypeTruncate
	SQLTypeExplain
	SQLTypeShow
	SQLTypeSet
)

// ClassifySQL analyzes a SQL statement and returns its type. The type is
// determined by the statement's first keyword, matched case-insensitively and
// as a whole word ("SELECTX" is SQLTypeUnknown).
//
// Leading whitespace, "--" line comments and "/* */" block comments are
// skipped first, so a comment cannot hide the real statement keyword. Block
// comments may nest, as in PostgreSQL. An unterminated block comment swallows
// the rest of the input, which gives SQLTypeUnknown.
//
// EXPLAIN ANALYZE executes the statement it explains. When that statement is a
// write (see IsWriteOperation), its own type is returned instead of
// SQLTypeExplain.
//
// Limitation: only the leading keyword is inspected. Writes inside a WITH
// clause (data-modifying CTEs) and extra statements after a ';' are not
// detected.
func ClassifySQL(sql string) SQLStatementType {
	keyword, rest := sqlLeadingKeyword(sql)
	switch keyword {
	case "SELECT":
		return SQLTypeSelect
	case "INSERT", "UPSERT":
		return SQLTypeInsert
	case "UPDATE":
		return SQLTypeUpdate
	case "DELETE":
		return SQLTypeDelete
	case "TRUNCATE":
		return SQLTypeTruncate
	case "CREATE", "ALTER", "DROP":
		return SQLTypeDDL
	case "EXPLAIN":
		return classifyExplainTarget(rest)
	case "SHOW":
		return SQLTypeShow
	case "SET":
		return SQLTypeSet
	default:
		return SQLTypeUnknown
	}
}

// classifyExplainTarget classifies the text that follows an EXPLAIN keyword.
// ANALYZE/ANALYSE can appear directly after EXPLAIN or inside a parenthesized
// option list. When it is present and the explained statement is a write, that
// statement's type is returned; otherwise the result is SQLTypeExplain.
func classifyExplainTarget(rest string) SQLStatementType {
	analyze := false
	if kw, after := sqlLeadingKeyword(rest); kw == "ANALYZE" || kw == "ANALYSE" {
		analyze, rest = true, after
	}
	if strings.HasPrefix(rest, "(") {
		if end := strings.IndexByte(rest, ')'); end >= 0 {
			options := strings.FieldsFunc(strings.ToUpper(rest[1:end]), func(r rune) bool {
				return r >= 0x80 || !isSQLWordByte(byte(r))
			})
			for _, opt := range options {
				if opt == "ANALYZE" || opt == "ANALYSE" {
					analyze = true
				}
			}
			rest = skipSQLSpaceAndComments(rest[end+1:])
		}
	}
	if analyze {
		if target := ClassifySQL(rest); IsWriteOperation(target) {
			return target
		}
	}
	return SQLTypeExplain
}

// sqlLeadingKeyword skips leading whitespace and comments in sql. It returns
// the first word upper-cased, plus the text after that word (whose own leading
// whitespace and comments are also skipped). A word is a run of ASCII letters,
// digits and underscores; keyword is "" when sql does not start with one.
func sqlLeadingKeyword(sql string) (keyword, rest string) {
	s := skipSQLSpaceAndComments(sql)
	n := 0
	for n < len(s) && isSQLWordByte(s[n]) {
		n++
	}
	return strings.ToUpper(s[:n]), skipSQLSpaceAndComments(s[n:])
}

// skipSQLSpaceAndComments strips leading ASCII whitespace, "--" line comments
// and block comments from s. A line comment ends at '\n' or '\r'; treating
// '\r' as an end-of-line errs on the side of exposing the next statement.
func skipSQLSpaceAndComments(s string) string {
	for {
		s = strings.TrimLeft(s, " \t\n\r\f\v")
		switch {
		case strings.HasPrefix(s, "--"):
			eol := strings.IndexAny(s, "\n\r")
			if eol < 0 {
				return ""
			}
			s = s[eol+1:]
		case strings.HasPrefix(s, "/*"):
			s = skipSQLBlockComment(s)
		default:
			return s
		}
	}
}

// skipSQLBlockComment removes the block comment at the start of s, which must
// begin with "/*". Nested "/* */" pairs are honored. If the comment is never
// closed, the whole of s is consumed and "" is returned.
func skipSQLBlockComment(s string) string {
	depth := 0
	for i := 0; i+1 < len(s); {
		switch {
		case s[i] == '/' && s[i+1] == '*':
			depth++
			i += 2
		case s[i] == '*' && s[i+1] == '/':
			depth--
			i += 2
			if depth == 0 {
				return s[i:]
			}
		default:
			i++
		}
	}
	return ""
}

// isSQLWordByte reports whether c can be part of an unquoted SQL keyword.
func isSQLWordByte(c byte) bool {
	return c == '_' || ('0' <= c && c <= '9') || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z')
}

// IsWriteOperation reports whether statements of type sqlType modify data or
// schema: INSERT/UPSERT, UPDATE, DELETE, TRUNCATE and DDL. SQLTypeUnknown is
// not treated as a write.
func IsWriteOperation(sqlType SQLStatementType) bool {
	switch sqlType {
	case SQLTypeInsert, SQLTypeUpdate, SQLTypeDelete, SQLTypeTruncate, SQLTypeDDL:
		return true
	default:
		return false
	}
}
|
|
|
|
// IsReadOnlyMode returns whether the source is in read-only mode
|
|
func (s *Source) IsReadOnlyMode() bool {
|
|
return s.ReadOnlyMode && !s.EnableWriteMode
|
|
}
|
|
|
|
// CanExecuteWrite checks if a write operation is allowed
|
|
func (s *Source) CanExecuteWrite(sql string) error {
|
|
sqlType := ClassifySQL(sql)
|
|
|
|
if IsWriteOperation(sqlType) && s.IsReadOnlyMode() {
|
|
return &StructuredError{
|
|
Code: ErrCodeReadOnlyViolation,
|
|
Message: "Write operations are not allowed in read-only mode. Set enableWriteMode: true to allow writes.",
|
|
Details: map[string]any{
|
|
"sql_type": sqlType,
|
|
"read_only_mode": s.ReadOnlyMode,
|
|
"enable_write_mode": s.EnableWriteMode,
|
|
},
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ApplyQueryLimits applies row limits to a SQL query for MCP security compliance.
|
|
// Context timeout management is the responsibility of the caller (following Go best practices).
|
|
// Returns potentially modified SQL with LIMIT clause for SELECT queries.
|
|
func (s *Source) ApplyQueryLimits(sql string) (string, error) {
|
|
sqlType := ClassifySQL(sql)
|
|
|
|
// Apply row limit only to SELECT queries
|
|
if sqlType == SQLTypeSelect && s.MaxRowLimit > 0 {
|
|
// Check if query already has LIMIT clause
|
|
normalized := strings.ToUpper(sql)
|
|
if !strings.Contains(normalized, " LIMIT ") {
|
|
// Add LIMIT clause - trim trailing whitespace and semicolon
|
|
sql = strings.TrimSpace(sql)
|
|
sql = strings.TrimSuffix(sql, ";")
|
|
sql = fmt.Sprintf("%s LIMIT %d", sql, s.MaxRowLimit)
|
|
}
|
|
}
|
|
|
|
return sql, nil
|
|
}
|
|
|
|
var (
	// sqlStringLiteralRe matches a single-quoted SQL string literal. A doubled
	// quote ('') inside it is an escaped quote, not a terminator.
	sqlStringLiteralRe = regexp.MustCompile(`'(?:[^']|'')*'`)
	// longNumberRe matches standalone runs of 10 or more digits.
	longNumberRe = regexp.MustCompile(`\b\d{10,}\b`)
)

// RedactSQL redacts sensitive values from SQL for telemetry. Every
// single-quoted string literal becomes '***', and every standalone number of
// 10 or more digits (phone numbers, IDs, card fragments) becomes ***. The
// patterns are compiled once at package scope because this runs per query.
func RedactSQL(sql string) string {
	sql = sqlStringLiteralRe.ReplaceAllString(sql, "'***'")
	return longNumberRe.ReplaceAllString(sql, "***")
}
|
|
|
|
// EmitTelemetry logs a telemetry event in structured JSON format
|
|
func (s *Source) EmitTelemetry(ctx context.Context, event TelemetryEvent) {
|
|
if !s.EnableTelemetry {
|
|
return
|
|
}
|
|
|
|
// Set cluster ID if not already set
|
|
if event.ClusterID == "" {
|
|
event.ClusterID = s.ClusterID
|
|
if event.ClusterID == "" {
|
|
event.ClusterID = s.Database // Fallback to database name
|
|
}
|
|
}
|
|
|
|
// Set database and user
|
|
if event.Database == "" {
|
|
event.Database = s.Database
|
|
}
|
|
if event.User == "" {
|
|
event.User = s.User
|
|
}
|
|
|
|
// Log as structured JSON
|
|
if s.TelemetryVerbose {
|
|
jsonBytes, _ := json.Marshal(event)
|
|
slog.Info("CockroachDB MCP Telemetry", "event", string(jsonBytes))
|
|
} else {
|
|
// Minimal logging
|
|
slog.Info("CockroachDB MCP",
|
|
"tool", event.ToolName,
|
|
"status", event.Status,
|
|
"latency_ms", event.LatencyMs,
|
|
"error_code", event.ErrorCode,
|
|
)
|
|
}
|
|
}
|
|
|
|
func initCockroachDBConnectionPoolWithRetry(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string, maxRetries int, baseDelay time.Duration) (*pgxpool.Pool, error) {
|
|
//nolint:all
|
|
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
|
defer span.End()
|
|
|
|
userAgent, err := util.UserAgentFromContext(ctx)
|
|
if err != nil {
|
|
userAgent = "genai-toolbox"
|
|
}
|
|
if queryParams == nil {
|
|
queryParams = make(map[string]string)
|
|
}
|
|
if _, ok := queryParams["application_name"]; !ok {
|
|
queryParams["application_name"] = userAgent
|
|
}
|
|
|
|
connURL := &url.URL{
|
|
Scheme: "postgres",
|
|
User: url.UserPassword(user, pass),
|
|
Host: fmt.Sprintf("%s:%s", host, port),
|
|
Path: dbname,
|
|
RawQuery: ConvertParamMapToRawQuery(queryParams),
|
|
}
|
|
|
|
var pool *pgxpool.Pool
|
|
for attempt := 0; attempt <= maxRetries; attempt++ {
|
|
pool, err = pgxpool.New(ctx, connURL.String())
|
|
if err == nil {
|
|
err = pool.Ping(ctx)
|
|
}
|
|
|
|
if err == nil {
|
|
return pool, nil
|
|
}
|
|
|
|
if attempt < maxRetries {
|
|
backoff := baseDelay * time.Duration(math.Pow(2, float64(attempt)))
|
|
time.Sleep(backoff)
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("failed to connect to CockroachDB after %d retries: %w", maxRetries, err)
|
|
}
|
|
|
|
// ConvertParamMapToRawQuery encodes queryParams as a URL query string such as
// "k1=v1&k2=v2". url.Values.Encode sorts the keys and escapes the values. A
// nil or empty map yields "".
func ConvertParamMapToRawQuery(queryParams map[string]string) string {
	encoded := make(url.Values, len(queryParams))
	for key, val := range queryParams {
		encoded.Set(key, val)
	}
	return encoded.Encode()
}
|