mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-05-02 03:00:36 -04:00
This PR updates the linking mechanism between Source and Tool.
Tools are directly linked to their Source, either by pointing to the
Source's functions or by assigning values from the source during Tool's
initialization. However, the existing approach means that any
modification to the Source after Tool's initialization might not be
reflected. To address this limitation, each tool should only store a
name reference to the Source, rather than direct link or assigned
values.
Tools will provide interface for `compatibleSource`. This will be used
to determine if a Source is compatible with the Tool.
```
type compatibleSource interface{
Client() http.Client
ProjectID() string
}
```
During `Invoke()`, the tool will run the following operations:
* retrieve the Source from the `resourceManager` using the source name defined
in the Tool's config
* validate Source via `compatibleSource interface{}`
* run the remaining `Invoke()` function. Fields that are needed are
retrieved directly from the source.
With this update, the resource manager is also added as an input to other
Tool functions that require access to the source (e.g.
`RequiresClientAuthorization()`).
525 lines
15 KiB
Go
525 lines
15 KiB
Go
// Copyright 2025 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package firestorequerycollection
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"strings"

	firestoreapi "cloud.google.com/go/firestore"
	yaml "github.com/goccy/go-yaml"
	"github.com/googleapis/genai-toolbox/internal/sources"
	"github.com/googleapis/genai-toolbox/internal/tools"
	"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
	"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
|
|
|
|
// Tool-wide configuration constants.
const (
	// kind is the tool kind string this package registers and reports.
	kind = "firestore-query-collection"
	// defaultLimit is the document limit used when the caller supplies none.
	defaultLimit = 100
	// defaultAnalyze is the default for the analyzeQuery parameter.
	defaultAnalyze = false
	// maxFilterLength caps how many filters a single call may carry, to
	// prevent abuse.
	maxFilterLength = 100
)
|
|
|
|
// Names of the parameters this tool accepts; they are used both when the
// parameters are declared and when their values are read back.
const (
	collectionPathKey = "collectionPath"
	filtersKey        = "filters"
	orderByKey        = "orderBy"
	limitKey          = "limit"
	analyzeQueryKey   = "analyzeQuery"
)
|
|
|
|
// validOperators is the set of comparison operators a filter may use.
// A lookup of any operator not listed here yields false.
var validOperators = map[string]bool{
	// Ordering comparisons.
	"<":  true,
	"<=": true,
	">":  true,
	">=": true,
	// Equality comparisons.
	"==": true,
	"!=": true,
	// Array and membership operators.
	"array-contains":     true,
	"array-contains-any": true,
	"in":                 true,
	"not-in":             true,
}
|
|
|
|
// Format strings for the errors this tool returns. Entries containing %w
// are meant for fmt.Errorf so the underlying error stays wrapped.
const (
	// Parameter-level problems.
	errMissingCollectionPath = "invalid or missing '%s' parameter"
	errInvalidFilters        = "invalid '%s' parameter; expected an array"

	// Problems with an individual filter.
	errFilterNotString    = "filter at index %d is not a string"
	errFilterParseFailed  = "failed to parse filter at index %d: %w"
	errInvalidOperator    = "unsupported operator: %s. Valid operators are: %v"
	errMissingFilterValue = "no value specified for filter on field '%s'"

	// Ordering, execution and size-limit problems.
	errOrderByParseFailed   = "failed to parse orderBy: %w"
	errQueryExecutionFailed = "failed to execute query: %w"
	errTooManyFilters       = "too many filters provided: %d (maximum: %d)"
)
|
|
|
|
func init() {
|
|
if !tools.Register(kind, newConfig) {
|
|
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
|
}
|
|
}
|
|
|
|
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
|
actual := Config{Name: name}
|
|
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
|
return nil, err
|
|
}
|
|
return actual, nil
|
|
}
|
|
|
|
// compatibleSource defines the interface for sources that can provide a Firestore client
|
|
type compatibleSource interface {
|
|
FirestoreClient() *firestoreapi.Client
|
|
}
|
|
|
|
// Config represents the configuration for the Firestore query collection tool
|
|
type Config struct {
|
|
Name string `yaml:"name" validate:"required"`
|
|
Kind string `yaml:"kind" validate:"required"`
|
|
Source string `yaml:"source" validate:"required"`
|
|
Description string `yaml:"description" validate:"required"`
|
|
AuthRequired []string `yaml:"authRequired"`
|
|
}
|
|
|
|
// validate interface
|
|
var _ tools.ToolConfig = Config{}
|
|
|
|
// ToolConfigKind returns the kind of tool configuration
|
|
func (cfg Config) ToolConfigKind() string {
|
|
return kind
|
|
}
|
|
|
|
// Initialize creates a new Tool instance from the configuration
|
|
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
|
// Create parameters
|
|
params := createParameters()
|
|
|
|
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil)
|
|
|
|
// finish tool setup
|
|
t := Tool{
|
|
Config: cfg,
|
|
Parameters: params,
|
|
manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired},
|
|
mcpManifest: mcpManifest,
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
func (t Tool) ToConfig() tools.ToolConfig {
|
|
return t.Config
|
|
}
|
|
|
|
// createParameters creates the parameter definitions for the tool
|
|
func createParameters() parameters.Parameters {
|
|
collectionPathParameter := parameters.NewStringParameter(
|
|
collectionPathKey,
|
|
"The relative path to the Firestore collection to query (e.g., 'users' or 'users/userId/posts'). Note: This is a relative path, NOT an absolute path like 'projects/{project_id}/databases/{database_id}/documents/...'",
|
|
)
|
|
|
|
filtersDescription := `Array of filter objects to apply to the query. Each filter is a JSON string with:
|
|
- field: The field name to filter on
|
|
- op: The operator to use ("<", "<=", ">", ">=", "==", "!=", "array-contains", "array-contains-any", "in", "not-in")
|
|
- value: The value to compare against (can be string, number, boolean, or array)
|
|
Example: {"field": "age", "op": ">", "value": 18}`
|
|
|
|
filtersParameter := parameters.NewArrayParameter(
|
|
filtersKey,
|
|
filtersDescription,
|
|
parameters.NewStringParameter("item", "JSON string representation of a filter object"),
|
|
)
|
|
|
|
orderByParameter := parameters.NewStringParameter(
|
|
orderByKey,
|
|
"JSON string specifying the field and direction to order by (e.g., {\"field\": \"name\", \"direction\": \"ASCENDING\"}). Leave empty if not specified",
|
|
)
|
|
|
|
limitParameter := parameters.NewIntParameterWithDefault(
|
|
limitKey,
|
|
defaultLimit,
|
|
"The maximum number of documents to return",
|
|
)
|
|
|
|
analyzeQueryParameter := parameters.NewBooleanParameterWithDefault(
|
|
analyzeQueryKey,
|
|
defaultAnalyze,
|
|
"If true, returns query explain metrics including execution statistics",
|
|
)
|
|
|
|
return parameters.Parameters{
|
|
collectionPathParameter,
|
|
filtersParameter,
|
|
orderByParameter,
|
|
limitParameter,
|
|
analyzeQueryParameter,
|
|
}
|
|
}
|
|
|
|
// validate interface
|
|
var _ tools.Tool = Tool{}
|
|
|
|
// Tool represents the Firestore query collection tool
|
|
type Tool struct {
|
|
Config
|
|
Parameters parameters.Parameters `yaml:"parameters"`
|
|
manifest tools.Manifest
|
|
mcpManifest tools.McpManifest
|
|
}
|
|
|
|
// FilterConfig is one query filter as decoded from its JSON representation.
type FilterConfig struct {
	// Field is the name of the document field the filter applies to.
	Field string `json:"field"`
	// Op is the comparison operator.
	Op string `json:"op"`
	// Value is the operand the field is compared against.
	Value any `json:"value"`
}
|
|
|
|
// Validate checks if the filter configuration is valid
|
|
func (f *FilterConfig) Validate() error {
|
|
if f.Field == "" {
|
|
return fmt.Errorf("filter field cannot be empty")
|
|
}
|
|
|
|
if !validOperators[f.Op] {
|
|
ops := make([]string, 0, len(validOperators))
|
|
for op := range validOperators {
|
|
ops = append(ops, op)
|
|
}
|
|
return fmt.Errorf(errInvalidOperator, f.Op, ops)
|
|
}
|
|
|
|
if f.Value == nil {
|
|
return fmt.Errorf(errMissingFilterValue, f.Field)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// OrderByConfig describes how query results are ordered, as decoded from the
// JSON string passed in the orderBy parameter.
type OrderByConfig struct {
	// Field is the name of the field to order by.
	Field string `json:"field"`
	// Direction is the requested sort direction; GetDirection interprets it.
	Direction string `json:"direction"`
}
|
|
|
|
// GetDirection returns the Firestore direction constant
|
|
func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
|
|
if strings.EqualFold(o.Direction, "DESCENDING") {
|
|
return firestoreapi.Desc
|
|
}
|
|
return firestoreapi.Asc
|
|
}
|
|
|
|
// QueryResult is the JSON shape of a single document returned by the query.
type QueryResult struct {
	// ID is the document ID.
	ID string `json:"id"`
	// Path is the document reference path.
	Path string `json:"path"`
	// Data holds the document's fields.
	Data map[string]any `json:"data"`
	// The timestamps below are left out of the JSON when unset.
	CreateTime any `json:"createTime,omitempty"`
	UpdateTime any `json:"updateTime,omitempty"`
	ReadTime   any `json:"readTime,omitempty"`
}
|
|
|
|
// QueryResponse represents the full response including optional metrics
|
|
type QueryResponse struct {
|
|
Documents []QueryResult `json:"documents"`
|
|
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
|
|
}
|
|
|
|
// Invoke executes the Firestore query based on the provided parameters
|
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse parameters
|
|
queryParams, err := t.parseQueryParameters(params)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Build the query
|
|
query, err := t.buildQuery(source, queryParams)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Execute the query and return results
|
|
return t.executeQuery(ctx, query, queryParams.AnalyzeQuery)
|
|
}
|
|
|
|
// queryParameters holds all parsed query parameters
|
|
type queryParameters struct {
|
|
CollectionPath string
|
|
Filters []FilterConfig
|
|
OrderBy *OrderByConfig
|
|
Limit int
|
|
AnalyzeQuery bool
|
|
}
|
|
|
|
// parseQueryParameters extracts and validates parameters from the input
|
|
func (t Tool) parseQueryParameters(params parameters.ParamValues) (*queryParameters, error) {
|
|
mapParams := params.AsMap()
|
|
|
|
// Get collection path
|
|
collectionPath, ok := mapParams[collectionPathKey].(string)
|
|
if !ok || collectionPath == "" {
|
|
return nil, fmt.Errorf(errMissingCollectionPath, collectionPathKey)
|
|
}
|
|
|
|
// Validate collection path
|
|
if err := util.ValidateCollectionPath(collectionPath); err != nil {
|
|
return nil, fmt.Errorf("invalid collection path: %w", err)
|
|
}
|
|
|
|
result := &queryParameters{
|
|
CollectionPath: collectionPath,
|
|
Limit: defaultLimit,
|
|
AnalyzeQuery: defaultAnalyze,
|
|
}
|
|
|
|
// Parse filters
|
|
if filtersRaw, ok := mapParams[filtersKey]; ok && filtersRaw != nil {
|
|
filters, err := t.parseFilters(filtersRaw)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result.Filters = filters
|
|
}
|
|
|
|
// Parse orderBy
|
|
if orderByRaw, ok := mapParams[orderByKey]; ok && orderByRaw != nil {
|
|
orderBy, err := t.parseOrderBy(orderByRaw)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result.OrderBy = orderBy
|
|
}
|
|
|
|
// Parse limit
|
|
if limit, ok := mapParams[limitKey].(int); ok {
|
|
result.Limit = limit
|
|
}
|
|
|
|
// Parse analyze
|
|
if analyze, ok := mapParams[analyzeQueryKey].(bool); ok {
|
|
result.AnalyzeQuery = analyze
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// parseFilters parses and validates filter configurations
|
|
func (t Tool) parseFilters(filtersRaw interface{}) ([]FilterConfig, error) {
|
|
filters, ok := filtersRaw.([]any)
|
|
if !ok {
|
|
return nil, fmt.Errorf(errInvalidFilters, filtersKey)
|
|
}
|
|
|
|
if len(filters) > maxFilterLength {
|
|
return nil, fmt.Errorf(errTooManyFilters, len(filters), maxFilterLength)
|
|
}
|
|
|
|
result := make([]FilterConfig, 0, len(filters))
|
|
for i, filterRaw := range filters {
|
|
filterJSON, ok := filterRaw.(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf(errFilterNotString, i)
|
|
}
|
|
|
|
var filter FilterConfig
|
|
if err := json.Unmarshal([]byte(filterJSON), &filter); err != nil {
|
|
return nil, fmt.Errorf(errFilterParseFailed, i, err)
|
|
}
|
|
|
|
if err := filter.Validate(); err != nil {
|
|
return nil, fmt.Errorf("filter at index %d is invalid: %w", i, err)
|
|
}
|
|
|
|
result = append(result, filter)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// parseOrderBy parses the orderBy configuration
|
|
func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) {
|
|
orderByJSON, ok := orderByRaw.(string)
|
|
if !ok || orderByJSON == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
var orderBy OrderByConfig
|
|
if err := json.Unmarshal([]byte(orderByJSON), &orderBy); err != nil {
|
|
return nil, fmt.Errorf(errOrderByParseFailed, err)
|
|
}
|
|
|
|
if orderBy.Field == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
return &orderBy, nil
|
|
}
|
|
|
|
// buildQuery constructs the Firestore query from parameters
|
|
func (t Tool) buildQuery(source compatibleSource, params *queryParameters) (*firestoreapi.Query, error) {
|
|
collection := source.FirestoreClient().Collection(params.CollectionPath)
|
|
query := collection.Query
|
|
|
|
// Apply filters
|
|
if len(params.Filters) > 0 {
|
|
filterConditions := make([]firestoreapi.EntityFilter, 0, len(params.Filters))
|
|
for _, filter := range params.Filters {
|
|
filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
|
|
Path: filter.Field,
|
|
Operator: filter.Op,
|
|
Value: filter.Value,
|
|
})
|
|
}
|
|
|
|
query = query.WhereEntity(firestoreapi.AndFilter{
|
|
Filters: filterConditions,
|
|
})
|
|
}
|
|
|
|
// Apply ordering
|
|
if params.OrderBy != nil {
|
|
query = query.OrderBy(params.OrderBy.Field, params.OrderBy.GetDirection())
|
|
}
|
|
|
|
// Apply limit
|
|
query = query.Limit(params.Limit)
|
|
|
|
// Apply analyze options
|
|
if params.AnalyzeQuery {
|
|
query = query.WithRunOptions(firestoreapi.ExplainOptions{
|
|
Analyze: true,
|
|
})
|
|
}
|
|
|
|
return &query, nil
|
|
}
|
|
|
|
// executeQuery runs the query and formats the results
|
|
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query, analyzeQuery bool) (any, error) {
|
|
docIterator := query.Documents(ctx)
|
|
docs, err := docIterator.GetAll()
|
|
if err != nil {
|
|
return nil, fmt.Errorf(errQueryExecutionFailed, err)
|
|
}
|
|
|
|
// Convert results to structured format
|
|
results := make([]QueryResult, len(docs))
|
|
for i, doc := range docs {
|
|
results[i] = QueryResult{
|
|
ID: doc.Ref.ID,
|
|
Path: doc.Ref.Path,
|
|
Data: doc.Data(),
|
|
CreateTime: doc.CreateTime,
|
|
UpdateTime: doc.UpdateTime,
|
|
ReadTime: doc.ReadTime,
|
|
}
|
|
}
|
|
|
|
// Return with explain metrics if requested
|
|
if analyzeQuery {
|
|
explainMetrics, err := t.getExplainMetrics(docIterator)
|
|
if err == nil && explainMetrics != nil {
|
|
response := QueryResponse{
|
|
Documents: results,
|
|
ExplainMetrics: explainMetrics,
|
|
}
|
|
return response, nil
|
|
}
|
|
}
|
|
|
|
// Return just the documents
|
|
resultsAny := make([]any, len(results))
|
|
for i, r := range results {
|
|
resultsAny[i] = r
|
|
}
|
|
return resultsAny, nil
|
|
}
|
|
|
|
// getExplainMetrics extracts explain metrics from the query iterator
|
|
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
|
|
explainMetrics, err := docIterator.ExplainMetrics()
|
|
if err != nil || explainMetrics == nil {
|
|
return nil, err
|
|
}
|
|
|
|
metricsData := make(map[string]any)
|
|
|
|
// Add plan summary if available
|
|
if explainMetrics.PlanSummary != nil {
|
|
planSummary := make(map[string]any)
|
|
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
|
|
metricsData["planSummary"] = planSummary
|
|
}
|
|
|
|
// Add execution stats if available
|
|
if explainMetrics.ExecutionStats != nil {
|
|
executionStats := make(map[string]any)
|
|
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
|
|
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
|
|
|
|
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
|
|
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
|
|
}
|
|
|
|
if explainMetrics.ExecutionStats.DebugStats != nil {
|
|
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
|
|
}
|
|
|
|
metricsData["executionStats"] = executionStats
|
|
}
|
|
|
|
return metricsData, nil
|
|
}
|
|
|
|
// ParseParams parses and validates input parameters
|
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
|
return parameters.ParseParams(t.Parameters, data, claims)
|
|
}
|
|
|
|
// Manifest returns the tool manifest
|
|
func (t Tool) Manifest() tools.Manifest {
|
|
return t.manifest
|
|
}
|
|
|
|
// McpManifest returns the MCP manifest
|
|
func (t Tool) McpManifest() tools.McpManifest {
|
|
return t.mcpManifest
|
|
}
|
|
|
|
// Authorized checks if the tool is authorized based on verified auth services
|
|
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
|
}
|
|
|
|
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
|
return false, nil
|
|
}
|
|
|
|
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
|
return "Authorization", nil
|
|
}
|