Mirror of https://github.com/googleapis/genai-toolbox.git (synced 2026-01-12 08:58:28 -05:00)
Compare commits: 1 commit on branch source-imp (3d5f166443)
@@ -19,8 +19,6 @@ import (
	"fmt"

	dataplexapi "cloud.google.com/go/dataplex/apiv1"
	"cloud.google.com/go/dataplex/apiv1/dataplexpb"
	"github.com/cenkalti/backoff/v5"
	"github.com/goccy/go-yaml"
	"github.com/googleapis/genai-toolbox/internal/sources"
	"github.com/googleapis/genai-toolbox/internal/util"

@@ -123,101 +121,3 @@ func initDataplexConnection(
	}
	return client, nil
}

func (s *Source) LookupEntry(ctx context.Context, name string, view int, aspectTypes []string, entry string) (*dataplexpb.Entry, error) {
	viewMap := map[int]dataplexpb.EntryView{
		1: dataplexpb.EntryView_BASIC,
		2: dataplexpb.EntryView_FULL,
		3: dataplexpb.EntryView_CUSTOM,
		4: dataplexpb.EntryView_ALL,
	}
	req := &dataplexpb.LookupEntryRequest{
		Name:        name,
		View:        viewMap[view],
		AspectTypes: aspectTypes,
		Entry:       entry,
	}
	result, err := s.CatalogClient().LookupEntry(ctx, req)
	if err != nil {
		return nil, err
	}
	return result, nil
}

func (s *Source) searchRequest(ctx context.Context, query string, pageSize int, orderBy string) (*dataplexapi.SearchEntriesResultIterator, error) {
	// Create SearchEntriesRequest with the provided parameters
	req := &dataplexpb.SearchEntriesRequest{
		Query:          query,
		Name:           fmt.Sprintf("projects/%s/locations/global", s.ProjectID()),
		PageSize:       int32(pageSize),
		OrderBy:        orderBy,
		SemanticSearch: true,
	}

	// Perform the search using the CatalogClient - this will return an iterator
	it := s.CatalogClient().SearchEntries(ctx, req)
	if it == nil {
		return nil, fmt.Errorf("failed to create search entries iterator for project %q", s.ProjectID())
	}
	return it, nil
}

func (s *Source) SearchAspectTypes(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.AspectType, error) {
	q := query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype"
	it, err := s.searchRequest(ctx, q, pageSize, orderBy)
	if err != nil {
		return nil, err
	}

	// Iterate through the search results and call GetAspectType for each result using the resource name
	var results []*dataplexpb.AspectType
	for {
		entry, err := it.Next()
		if err != nil {
			break
		}

		// Create an instance of exponential backoff with default values for retrying GetAspectType calls
		// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
		getAspectBackOff := backoff.NewExponentialBackOff()

		resourceName := entry.DataplexEntry.GetEntrySource().Resource
		getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
			Name: resourceName,
		}

		operation := func() (*dataplexpb.AspectType, error) {
			aspectType, err := s.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
			if err != nil {
				return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
			}
			return aspectType, nil
		}

		// Retry the GetAspectType operation with exponential backoff
		aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
		if err != nil {
			return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
		}

		results = append(results, aspectType)
	}
	return results, nil
}

func (s *Source) SearchEntries(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.SearchEntriesResult, error) {
	it, err := s.searchRequest(ctx, query, pageSize, orderBy)
	if err != nil {
		return nil, err
	}

	var results []*dataplexpb.SearchEntriesResult
	for {
		entry, err := it.Next()
		if err != nil {
			break
		}
		results = append(results, entry)
	}
	return results, nil
}
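A minimal usage sketch for the Source methods above, assuming a configured *Source named src; the project, location, and entry names are hypothetical:

	// view 2 selects dataplexpb.EntryView_FULL via the viewMap above.
	entry, err := src.LookupEntry(ctx,
		"projects/my-project/locations/us-central1", // hypothetical scoping name
		2,   // FULL view
		nil, // no aspect-type filter
		"projects/my-project/locations/us-central1/entryGroups/eg/entries/e1") // hypothetical entry
	if err != nil {
		return err
	}

	// Semantic search across the catalog; iteration stops at the first iterator error.
	hits, err := src.SearchEntries(ctx, "customer orders", 20, "relevance")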
@@ -16,7 +16,10 @@ package firestore

import (
	"context"
	"encoding/base64"
	"fmt"
	"strings"
	"time"

	"cloud.google.com/go/firestore"
	"github.com/goccy/go-yaml"
@@ -25,6 +28,7 @@ import (
	"go.opentelemetry.io/otel/trace"
	"google.golang.org/api/firebaserules/v1"
	"google.golang.org/api/option"
	"google.golang.org/genproto/googleapis/type/latlng"
)

const SourceKind string = "firestore"
@@ -113,6 +117,476 @@ func (s *Source) GetDatabaseId() string {
	return s.Database
}

// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
// This removes type information and returns plain values
func FirestoreValueToJSON(value any) any {
	if value == nil {
		return nil
	}

	switch v := value.(type) {
	case time.Time:
		return v.Format(time.RFC3339Nano)
	case *latlng.LatLng:
		return map[string]any{
			"latitude":  v.Latitude,
			"longitude": v.Longitude,
		}
	case []byte:
		return base64.StdEncoding.EncodeToString(v)
	case []any:
		result := make([]any, len(v))
		for i, item := range v {
			result[i] = FirestoreValueToJSON(item)
		}
		return result
	case map[string]any:
		result := make(map[string]any)
		for k, val := range v {
			result[k] = FirestoreValueToJSON(val)
		}
		return result
	case *firestore.DocumentRef:
		return v.Path
	default:
		return value
	}
}
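For illustration only, here is roughly how typed values flatten through FirestoreValueToJSON (the input values are made up):

	raw := map[string]any{
		"when":  time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC),
		"where": &latlng.LatLng{Latitude: 40.7, Longitude: -74.0},
		"blob":  []byte{0x01, 0x02},
	}
	fmt.Println(FirestoreValueToJSON(raw))
	// map[blob:AQI= when:2025-01-02T03:04:05Z where:map[latitude:40.7 longitude:-74]]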
// BuildQuery constructs the Firestore query from parameters
func (s *Source) BuildQuery(collectionPath string, filter firestore.EntityFilter, selectFields []string, field string, direction firestore.Direction, limit int, analyzeQuery bool) (*firestore.Query, error) {
	collection := s.FirestoreClient().Collection(collectionPath)
	query := collection.Query

	// Process and apply filters if template is provided
	if filter != nil {
		query = query.WhereEntity(filter)
	}
	if len(selectFields) > 0 {
		query = query.Select(selectFields...)
	}
	if field != "" {
		query = query.OrderBy(field, direction)
	}
	query = query.Limit(limit)

	// Apply analyze options if enabled
	if analyzeQuery {
		query = query.WithRunOptions(firestore.ExplainOptions{
			Analyze: true,
		})
	}

	return &query, nil
}

// QueryResult represents a document result from the query
type QueryResult struct {
	ID         string         `json:"id"`
	Path       string         `json:"path"`
	Data       map[string]any `json:"data"`
	CreateTime any            `json:"createTime,omitempty"`
	UpdateTime any            `json:"updateTime,omitempty"`
	ReadTime   any            `json:"readTime,omitempty"`
}

// QueryResponse represents the full response including optional metrics
type QueryResponse struct {
	Documents      []QueryResult  `json:"documents"`
	ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
}

// ExecuteQuery runs the query and formats the results
func (s *Source) ExecuteQuery(ctx context.Context, query *firestore.Query, analyzeQuery bool) (any, error) {
	docIterator := query.Documents(ctx)
	docs, err := docIterator.GetAll()
	if err != nil {
		return nil, fmt.Errorf("failed to execute query: %w", err)
	}
	// Convert results to structured format
	results := make([]QueryResult, len(docs))
	for i, doc := range docs {
		results[i] = QueryResult{
			ID:         doc.Ref.ID,
			Path:       doc.Ref.Path,
			Data:       doc.Data(),
			CreateTime: doc.CreateTime,
			UpdateTime: doc.UpdateTime,
			ReadTime:   doc.ReadTime,
		}
	}

	// Return with explain metrics if requested
	if analyzeQuery {
		explainMetrics, err := getExplainMetrics(docIterator)
		if err == nil && explainMetrics != nil {
			response := QueryResponse{
				Documents:      results,
				ExplainMetrics: explainMetrics,
			}
			return response, nil
		}
	}
	return results, nil
}
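A minimal end-to-end sketch, assuming a configured *Source named src and a hypothetical "users" collection:

	q, err := src.BuildQuery("users", nil /* no filter */, nil /* all fields */,
		"age", firestore.Desc, 10, false /* no query analysis */)
	if err != nil {
		return err
	}
	out, err := src.ExecuteQuery(ctx, q, false)
	// out is []QueryResult here; with analyzeQuery=true it may instead be a
	// QueryResponse that also carries explain metrics.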
// getExplainMetrics extracts explain metrics from the query iterator
func getExplainMetrics(docIterator *firestore.DocumentIterator) (map[string]any, error) {
	explainMetrics, err := docIterator.ExplainMetrics()
	if err != nil || explainMetrics == nil {
		return nil, err
	}

	metricsData := make(map[string]any)

	// Add plan summary if available
	if explainMetrics.PlanSummary != nil {
		planSummary := make(map[string]any)
		planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
		metricsData["planSummary"] = planSummary
	}

	// Add execution stats if available
	if explainMetrics.ExecutionStats != nil {
		executionStats := make(map[string]any)
		executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
		executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations

		if explainMetrics.ExecutionStats.ExecutionDuration != nil {
			executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
		}

		if explainMetrics.ExecutionStats.DebugStats != nil {
			executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
		}

		metricsData["executionStats"] = executionStats
	}

	return metricsData, nil
}

func (s *Source) GetDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
	// Create document references from paths
	docRefs := make([]*firestore.DocumentRef, len(documentPaths))
	for i, path := range documentPaths {
		docRefs[i] = s.FirestoreClient().Doc(path)
	}

	// Get all documents
	snapshots, err := s.FirestoreClient().GetAll(ctx, docRefs)
	if err != nil {
		return nil, fmt.Errorf("failed to get documents: %w", err)
	}

	// Convert snapshots to response data
	results := make([]any, len(snapshots))
	for i, snapshot := range snapshots {
		docData := make(map[string]any)
		docData["path"] = documentPaths[i]
		docData["exists"] = snapshot.Exists()

		if snapshot.Exists() {
			docData["data"] = snapshot.Data()
			docData["createTime"] = snapshot.CreateTime
			docData["updateTime"] = snapshot.UpdateTime
			docData["readTime"] = snapshot.ReadTime
		}

		results[i] = docData
	}

	return results, nil
}

func (s *Source) AddDocuments(ctx context.Context, collectionPath string, documentData any, returnData bool) (map[string]any, error) {
	// Get the collection reference
	collection := s.FirestoreClient().Collection(collectionPath)

	// Add the document to the collection
	docRef, writeResult, err := collection.Add(ctx, documentData)
	if err != nil {
		return nil, fmt.Errorf("failed to add document: %w", err)
	}
	// Build the response
	response := map[string]any{
		"documentPath": docRef.Path,
		"createTime":   writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
	}
	// Add document data if requested
	if returnData {
		// Fetch the updated document to return the current state
		snapshot, err := docRef.Get(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
		}
		// Convert the document data back to simple JSON format
		simplifiedData := FirestoreValueToJSON(snapshot.Data())
		response["documentData"] = simplifiedData
	}
	return response, nil
}
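A usage sketch for AddDocuments, assuming a configured *Source named src; the collection and fields are hypothetical:

	resp, err := src.AddDocuments(ctx, "users", map[string]any{"name": "Ada"}, true)
	if err != nil {
		return err
	}
	// resp["documentPath"] and resp["createTime"] are always set; because
	// returnData is true, resp["documentData"] holds the stored fields in
	// simplified JSON form (via FirestoreValueToJSON).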
func (s *Source) UpdateDocument(ctx context.Context, documentPath string, updates []firestore.Update, documentData any, returnData bool) (map[string]any, error) {
	// Get the document reference
	docRef := s.FirestoreClient().Doc(documentPath)

	// Prepare update data
	var writeResult *firestore.WriteResult
	var writeErr error

	if len(updates) > 0 {
		writeResult, writeErr = docRef.Update(ctx, updates)
	} else {
		writeResult, writeErr = docRef.Set(ctx, documentData, firestore.MergeAll)
	}

	if writeErr != nil {
		return nil, fmt.Errorf("failed to update document: %w", writeErr)
	}

	// Build the response
	response := map[string]any{
		"documentPath": docRef.Path,
		"updateTime":   writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
	}

	// Add document data if requested
	if returnData {
		// Fetch the updated document to return the current state
		snapshot, err := docRef.Get(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
		}
		// Convert the document data to simple JSON format
		simplifiedData := FirestoreValueToJSON(snapshot.Data())
		response["documentData"] = simplifiedData
	}

	return response, nil
}

func (s *Source) DeleteDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
	// Create a BulkWriter to handle multiple deletions efficiently
	bulkWriter := s.FirestoreClient().BulkWriter(ctx)

	// Keep track of jobs for each document
	jobs := make([]*firestore.BulkWriterJob, len(documentPaths))

	// Add all delete operations to the BulkWriter
	for i, path := range documentPaths {
		docRef := s.FirestoreClient().Doc(path)
		job, err := bulkWriter.Delete(docRef)
		if err != nil {
			return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
		}
		jobs[i] = job
	}

	// End the BulkWriter to execute all operations
	bulkWriter.End()

	// Collect results
	results := make([]any, len(documentPaths))
	for i, job := range jobs {
		docData := make(map[string]any)
		docData["path"] = documentPaths[i]

		// Wait for the job to complete and get the result
		_, err := job.Results()
		if err != nil {
			docData["success"] = false
			docData["error"] = err.Error()
		} else {
			docData["success"] = true
		}

		results[i] = docData
	}
	return results, nil
}
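Because BulkWriter jobs only resolve once End() has flushed all writes, each path gets an independent outcome. A usage sketch, with hypothetical document paths:

	results, err := src.DeleteDocuments(ctx, []string{"users/ada", "users/grace"})
	// each element looks like {"path": ..., "success": true}
	// or {"path": ..., "success": false, "error": "..."}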
func (s *Source) ListCollections(ctx context.Context, parentPath string) ([]any, error) {
	var collectionRefs []*firestore.CollectionRef
	var err error
	if parentPath != "" {
		// List subcollections of the specified document
		docRef := s.FirestoreClient().Doc(parentPath)
		collectionRefs, err = docRef.Collections(ctx).GetAll()
		if err != nil {
			return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
		}
	} else {
		// List root collections
		collectionRefs, err = s.FirestoreClient().Collections(ctx).GetAll()
		if err != nil {
			return nil, fmt.Errorf("failed to list root collections: %w", err)
		}
	}

	// Convert collection references to response data
	results := make([]any, len(collectionRefs))
	for i, collRef := range collectionRefs {
		collData := make(map[string]any)
		collData["id"] = collRef.ID
		collData["path"] = collRef.Path

		// If this is a subcollection, include parent information
		if collRef.Parent != nil {
			collData["parent"] = collRef.Parent.Path
		}
		results[i] = collData
	}
	return results, nil
}

func (s *Source) GetRules(ctx context.Context) (any, error) {
	// Get the latest release for Firestore
	releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", s.GetProjectId(), s.GetDatabaseId())
	release, err := s.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
	}

	if release.RulesetName == "" {
		return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", s.GetProjectId(), s.GetDatabaseId())
	}

	// Get the ruleset content
	ruleset, err := s.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to get ruleset content: %w", err)
	}

	if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
		return nil, fmt.Errorf("no rules files found in ruleset")
	}

	return ruleset, nil
}

// SourcePosition represents the location of an issue in the source
type SourcePosition struct {
	FileName      string `json:"fileName,omitempty"`
	Line          int64  `json:"line"`          // 1-based
	Column        int64  `json:"column"`        // 1-based
	CurrentOffset int64  `json:"currentOffset"` // 0-based, inclusive start
	EndOffset     int64  `json:"endOffset"`     // 0-based, exclusive end
}

// Issue represents a validation issue in the rules
type Issue struct {
	SourcePosition SourcePosition `json:"sourcePosition"`
	Description    string         `json:"description"`
	Severity       string         `json:"severity"`
}

// ValidationResult represents the result of rules validation
type ValidationResult struct {
	Valid           bool    `json:"valid"`
	IssueCount      int     `json:"issueCount"`
	FormattedIssues string  `json:"formattedIssues,omitempty"`
	RawIssues       []Issue `json:"rawIssues,omitempty"`
}

func (s *Source) ValidateRules(ctx context.Context, sourceParam string) (any, error) {
	// Create test request
	testRequest := &firebaserules.TestRulesetRequest{
		Source: &firebaserules.Source{
			Files: []*firebaserules.File{
				{
					Name:    "firestore.rules",
					Content: sourceParam,
				},
			},
		},
		// We don't need test cases for validation only
		TestSuite: &firebaserules.TestSuite{
			TestCases: []*firebaserules.TestCase{},
		},
	}
	// Call the test API
	projectName := fmt.Sprintf("projects/%s", s.GetProjectId())
	response, err := s.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to validate rules: %w", err)
	}

	// Process the response
	if len(response.Issues) == 0 {
		return ValidationResult{
			Valid:           true,
			IssueCount:      0,
			FormattedIssues: "✓ No errors detected. Rules are valid.",
		}, nil
	}

	// Convert issues to our format
	issues := make([]Issue, len(response.Issues))
	for i, issue := range response.Issues {
		issues[i] = Issue{
			Description: issue.Description,
			Severity:    issue.Severity,
			SourcePosition: SourcePosition{
				FileName:      issue.SourcePosition.FileName,
				Line:          issue.SourcePosition.Line,
				Column:        issue.SourcePosition.Column,
				CurrentOffset: issue.SourcePosition.CurrentOffset,
				EndOffset:     issue.SourcePosition.EndOffset,
			},
		}
	}

	// Format issues
	sourceLines := strings.Split(sourceParam, "\n")
	var formattedOutput []string

	formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))

	for _, issue := range issues {
		issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
			issue.Severity,
			issue.Description,
			issue.SourcePosition.Line,
			issue.SourcePosition.Column)

		if issue.SourcePosition.Line > 0 {
			lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
			if lineIndex >= 0 && lineIndex < len(sourceLines) {
				errorLine := sourceLines[lineIndex]
				issueString += fmt.Sprintf("\n```\n%s", errorLine)

				// Add carets if we have column and offset information
				if issue.SourcePosition.Column > 0 &&
					issue.SourcePosition.CurrentOffset >= 0 &&
					issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {

					startColumn := int(issue.SourcePosition.Column - 1) // 0-based
					errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)

					if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
						padding := strings.Repeat(" ", startColumn)
						carets := strings.Repeat("^", errorTokenLength)
						issueString += fmt.Sprintf("\n%s%s", padding, carets)
					}
				}
				issueString += "\n```"
			}
		}

		formattedOutput = append(formattedOutput, issueString)
	}

	formattedIssues := strings.Join(formattedOutput, "\n\n")

	return ValidationResult{
		Valid:           false,
		IssueCount:      len(issues),
		FormattedIssues: formattedIssues,
		RawIssues:       issues,
	}, nil
}
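For illustration, a hypothetical single-issue response would render its FormattedIssues field roughly as (padding comes from Column-1, carets from EndOffset-CurrentOffset):

	Found 1 issue(s) in rules source:

	ERROR: Unexpected 'allow'. [Ln 4, Col 5]
	```
	    allow read: if true;
	    ^^^^^
	```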
func initFirestoreConnection(
	ctx context.Context,
	tracer trace.Tracer,

@@ -16,6 +16,7 @@ package firestore_test

import (
	"testing"
	"time"

	yaml "github.com/goccy/go-yaml"
	"github.com/google/go-cmp/cmp"
@@ -128,3 +129,37 @@ func TestFailParseFromYamlFirestore(t *testing.T) {
		})
	}
}

func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
	// Test round-trip conversion
	original := map[string]any{
		"name":   "Test",
		"count":  int64(42),
		"price":  19.99,
		"active": true,
		"tags":   []any{"tag1", "tag2"},
		"metadata": map[string]any{
			"created": time.Now(),
		},
		"nullField": nil,
	}

	// Convert to JSON representation
	jsonRepresentation := firestore.FirestoreValueToJSON(original)

	// Verify types are simplified
	jsonMap, ok := jsonRepresentation.(map[string]any)
	if !ok {
		t.Fatalf("Expected map, got %T", jsonRepresentation)
	}

	// Time should be converted to string
	metadata, ok := jsonMap["metadata"].(map[string]any)
	if !ok {
		t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
	}
	_, ok = metadata["created"].(string)
	if !ok {
		t.Errorf("created should be a string, got %T", metadata["created"])
	}
}
@@ -16,9 +16,7 @@ package http
import (
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"
@@ -145,28 +143,3 @@ func (s *Source) HttpQueryParams() map[string]string {
func (s *Source) Client() *http.Client {
	return s.client
}

func (s *Source) RunRequest(req *http.Request) (any, error) {
	// Make request and fetch response
	resp, err := s.Client().Do(req)
	if err != nil {
		return nil, fmt.Errorf("error making HTTP request: %s", err)
	}
	defer resp.Body.Close()

	var body []byte
	body, err = io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
	}

	var data any
	if err = json.Unmarshal(body, &data); err != nil {
		// if unable to unmarshal data, return result as string.
		return string(body), nil
	}
	return data, nil
}
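A usage sketch for RunRequest, assuming a configured *Source named src; the URL is hypothetical:

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://example.com/api/items", nil)
	if err != nil {
		return err
	}
	data, err := src.RunRequest(req)
	// data is the decoded JSON value on success; if the body is not valid JSON,
	// it is returned verbatim as a string instead.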
@@ -16,21 +16,15 @@ package serverlessspark

import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	dataproc "cloud.google.com/go/dataproc/v2/apiv1"
	"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
	longrunning "cloud.google.com/go/longrunning/autogen"
	"cloud.google.com/go/longrunning/autogen/longrunningpb"
	"github.com/goccy/go-yaml"
	"github.com/googleapis/genai-toolbox/internal/sources"
	"github.com/googleapis/genai-toolbox/internal/util"
	"go.opentelemetry.io/otel/trace"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
	"google.golang.org/protobuf/encoding/protojson"
)

const SourceKind string = "serverless-spark"
@@ -127,168 +121,3 @@ func (s *Source) Close() error {
	}
	return nil
}

func (s *Source) CancelOperation(ctx context.Context, operation string) (any, error) {
	req := &longrunningpb.CancelOperationRequest{
		Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", s.GetProject(), s.GetLocation(), operation),
	}
	client, err := s.GetOperationsClient(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get operations client: %w", err)
	}
	err = client.CancelOperation(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to cancel operation: %w", err)
	}
	return fmt.Sprintf("Cancelled [%s].", operation), nil
}

func (s *Source) CreateBatch(ctx context.Context, batch *dataprocpb.Batch) (map[string]any, error) {
	req := &dataprocpb.CreateBatchRequest{
		Parent: fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation()),
		Batch:  batch,
	}

	client := s.GetBatchControllerClient()
	op, err := client.CreateBatch(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to create batch: %w", err)
	}
	meta, err := op.Metadata()
	if err != nil {
		return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
	}

	projectID, location, batchID, err := ExtractBatchDetails(meta.GetBatch())
	if err != nil {
		return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
	}
	consoleUrl := BatchConsoleURL(projectID, location, batchID)
	logsUrl := BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})

	wrappedResult := map[string]any{
		"opMetadata": meta,
		"consoleUrl": consoleUrl,
		"logsUrl":    logsUrl,
	}
	return wrappedResult, nil
}
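A sketch of creating a PySpark batch through CreateBatch, assuming a configured *Source named src; the bucket and file URI are hypothetical, and the field names follow the dataprocpb generated-code conventions:

	batch := &dataprocpb.Batch{
		BatchConfig: &dataprocpb.Batch_PysparkBatch{
			PysparkBatch: &dataprocpb.PySparkBatch{
				MainPythonFileUri: "gs://my-bucket/job.py", // hypothetical
			},
		},
	}
	out, err := src.CreateBatch(ctx, batch)
	// out carries "opMetadata" plus ready-made "consoleUrl" and "logsUrl" links.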
// ListBatchesResponse is the response from the list batches API.
type ListBatchesResponse struct {
	Batches       []Batch `json:"batches"`
	NextPageToken string  `json:"nextPageToken"`
}

// Batch represents a single batch job.
type Batch struct {
	Name       string `json:"name"`
	UUID       string `json:"uuid"`
	State      string `json:"state"`
	Creator    string `json:"creator"`
	CreateTime string `json:"createTime"`
	Operation  string `json:"operation"`
	ConsoleURL string `json:"consoleUrl"`
	LogsURL    string `json:"logsUrl"`
}

func (s *Source) ListBatches(ctx context.Context, ps *int, pt, filter string) (any, error) {
	client := s.GetBatchControllerClient()
	parent := fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation())
	req := &dataprocpb.ListBatchesRequest{
		Parent:  parent,
		OrderBy: "create_time desc",
	}

	if ps != nil {
		req.PageSize = int32(*ps)
	}
	if pt != "" {
		req.PageToken = pt
	}
	if filter != "" {
		req.Filter = filter
	}

	it := client.ListBatches(ctx, req)
	pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)

	var batchPbs []*dataprocpb.Batch
	nextPageToken, err := pager.NextPage(&batchPbs)
	if err != nil {
		return nil, fmt.Errorf("failed to list batches: %w", err)
	}

	batches, err := ToBatches(batchPbs)
	if err != nil {
		return nil, err
	}

	return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
}
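A paging sketch: feed the returned NextPageToken back in to fetch the next page (the page size here is arbitrary):

	size := 25
	first, err := src.ListBatches(ctx, &size, "", "")
	if err != nil {
		return err
	}
	resp := first.(ListBatchesResponse)
	if resp.NextPageToken != "" {
		next, err := src.ListBatches(ctx, &size, resp.NextPageToken, "")
		// ... handle the second page
	}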
// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
	batches := make([]Batch, 0, len(batchPbs))
	for _, batchPb := range batchPbs {
		consoleUrl, err := BatchConsoleURLFromProto(batchPb)
		if err != nil {
			return nil, fmt.Errorf("error generating console url: %v", err)
		}
		logsUrl, err := BatchLogsURLFromProto(batchPb)
		if err != nil {
			return nil, fmt.Errorf("error generating logs url: %v", err)
		}
		batch := Batch{
			Name:       batchPb.Name,
			UUID:       batchPb.Uuid,
			State:      batchPb.State.Enum().String(),
			Creator:    batchPb.Creator,
			CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
			Operation:  batchPb.Operation,
			ConsoleURL: consoleUrl,
			LogsURL:    logsUrl,
		}
		batches = append(batches, batch)
	}
	return batches, nil
}

func (s *Source) GetBatch(ctx context.Context, name string) (map[string]any, error) {
	client := s.GetBatchControllerClient()
	req := &dataprocpb.GetBatchRequest{
		Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", s.GetProject(), s.GetLocation(), name),
	}

	batchPb, err := client.GetBatch(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to get batch: %w", err)
	}

	jsonBytes, err := protojson.Marshal(batchPb)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
	}

	var result map[string]any
	if err := json.Unmarshal(jsonBytes, &result); err != nil {
		return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
	}

	consoleUrl, err := BatchConsoleURLFromProto(batchPb)
	if err != nil {
		return nil, fmt.Errorf("error generating console url: %v", err)
	}
	logsUrl, err := BatchLogsURLFromProto(batchPb)
	if err != nil {
		return nil, fmt.Errorf("error generating logs url: %v", err)
	}

	wrappedResult := map[string]any{
		"consoleUrl": consoleUrl,
		"logsUrl":    logsUrl,
		"batch":      result,
	}

	return wrappedResult, nil
}
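The protojson round-trip in GetBatch is what turns the strongly typed proto into a plain map keyed by the proto's JSON field names, so it can be wrapped alongside the extra link fields. An illustrative (made-up) result shape:

	{"consoleUrl": "...", "logsUrl": "...", "batch": {"name": "projects/p/locations/l/batches/b", "state": "SUCCEEDED", ...}}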
@@ -18,6 +18,7 @@ import (
	"context"
	"fmt"

	dataplexapi "cloud.google.com/go/dataplex/apiv1"
	dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
	"github.com/goccy/go-yaml"
	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
@@ -43,7 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}

type compatibleSource interface {
	LookupEntry(context.Context, string, int, []string, string) (*dataplexpb.Entry, error)
	CatalogClient() *dataplexapi.CatalogClient
}

type Config struct {
@@ -117,6 +118,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
	}

	paramsMap := params.AsMap()
	viewMap := map[int]dataplexpb.EntryView{
		1: dataplexpb.EntryView_BASIC,
		2: dataplexpb.EntryView_FULL,
		3: dataplexpb.EntryView_CUSTOM,
		4: dataplexpb.EntryView_ALL,
	}
	name, _ := paramsMap["name"].(string)
	entry, _ := paramsMap["entry"].(string)
	view, _ := paramsMap["view"].(int)
@@ -125,7 +132,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
		return nil, fmt.Errorf("can't convert aspectTypes to array of strings: %s", err)
	}
	aspectTypes := aspectTypeSlice.([]string)
	return source.LookupEntry(ctx, name, view, aspectTypes, entry)

	req := &dataplexpb.LookupEntryRequest{
		Name:        name,
		View:        viewMap[view],
		AspectTypes: aspectTypes,
		Entry:       entry,
	}

	result, err := source.CatalogClient().LookupEntry(ctx, req)
	if err != nil {
		return nil, err
	}
	return result, nil
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -18,7 +18,9 @@ import (
	"context"
	"fmt"

	"cloud.google.com/go/dataplex/apiv1/dataplexpb"
	dataplexapi "cloud.google.com/go/dataplex/apiv1"
	dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
	"github.com/cenkalti/backoff/v5"
	"github.com/goccy/go-yaml"
	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
	"github.com/googleapis/genai-toolbox/internal/sources"
@@ -43,7 +45,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}

type compatibleSource interface {
	SearchAspectTypes(context.Context, string, int, string) ([]*dataplexpb.AspectType, error)
	CatalogClient() *dataplexapi.CatalogClient
	ProjectID() string
}

type Config struct {
@@ -98,11 +101,61 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
	if err != nil {
		return nil, err
	}

	// Invoke the tool with the provided parameters
	paramsMap := params.AsMap()
	query, _ := paramsMap["query"].(string)
	pageSize, _ := paramsMap["pageSize"].(int)
	pageSize := int32(paramsMap["pageSize"].(int))
	orderBy, _ := paramsMap["orderBy"].(string)
	return source.SearchAspectTypes(ctx, query, pageSize, orderBy)

	// Create SearchEntriesRequest with the provided parameters
	req := &dataplexpb.SearchEntriesRequest{
		Query:          query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype",
		Name:           fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
		PageSize:       pageSize,
		OrderBy:        orderBy,
		SemanticSearch: true,
	}

	// Perform the search using the CatalogClient - this will return an iterator
	it := source.CatalogClient().SearchEntries(ctx, req)
	if it == nil {
		return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
	}

	// Create an instance of exponential backoff with default values for retrying GetAspectType calls
	// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
	getAspectBackOff := backoff.NewExponentialBackOff()

	// Iterate through the search results and call GetAspectType for each result using the resource name
	var results []*dataplexpb.AspectType
	for {
		entry, err := it.Next()
		if err != nil {
			break
		}
		resourceName := entry.DataplexEntry.GetEntrySource().Resource
		getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
			Name: resourceName,
		}

		operation := func() (*dataplexpb.AspectType, error) {
			aspectType, err := source.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
			if err != nil {
				return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
			}
			return aspectType, nil
		}

		// Retry the GetAspectType operation with exponential backoff
		aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
		if err != nil {
			return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
		}

		results = append(results, aspectType)
	}
	return results, nil
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -18,7 +18,8 @@ import (
	"context"
	"fmt"

	"cloud.google.com/go/dataplex/apiv1/dataplexpb"
	dataplexapi "cloud.google.com/go/dataplex/apiv1"
	dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
	"github.com/goccy/go-yaml"
	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
	"github.com/googleapis/genai-toolbox/internal/sources"
@@ -43,7 +44,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}

type compatibleSource interface {
	SearchEntries(context.Context, string, int, string) ([]*dataplexpb.SearchEntriesResult, error)
	CatalogClient() *dataplexapi.CatalogClient
	ProjectID() string
}

type Config struct {
@@ -98,11 +100,34 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
	if err != nil {
		return nil, err
	}

	paramsMap := params.AsMap()
	query, _ := paramsMap["query"].(string)
	pageSize, _ := paramsMap["pageSize"].(int)
	pageSize := int32(paramsMap["pageSize"].(int))
	orderBy, _ := paramsMap["orderBy"].(string)
	return source.SearchEntries(ctx, query, pageSize, orderBy)

	req := &dataplexpb.SearchEntriesRequest{
		Query:          query,
		Name:           fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
		PageSize:       pageSize,
		OrderBy:        orderBy,
		SemanticSearch: true,
	}

	it := source.CatalogClient().SearchEntries(ctx, req)
	if it == nil {
		return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
	}

	var results []*dataplexpb.SearchEntriesResult
	for {
		entry, err := it.Next()
		if err != nil {
			break
		}
		results = append(results, entry)
	}
	return results, nil
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -48,6 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
	FirestoreClient() *firestoreapi.Client
	AddDocuments(context.Context, string, any, bool) (map[string]any, error)
}

type Config struct {
@@ -134,24 +135,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
	}

	mapParams := params.AsMap()

	// Get collection path
	collectionPath, ok := mapParams[collectionPathKey].(string)
	if !ok || collectionPath == "" {
		return nil, fmt.Errorf("invalid or missing '%s' parameter", collectionPathKey)
	}

	// Validate collection path
	if err := util.ValidateCollectionPath(collectionPath); err != nil {
		return nil, fmt.Errorf("invalid collection path: %w", err)
	}

	// Get document data
	documentDataRaw, ok := mapParams[documentDataKey]
	if !ok {
		return nil, fmt.Errorf("invalid or missing '%s' parameter", documentDataKey)
	}

	// Convert the document data from JSON format to Firestore format
	// The client is passed to handle referenceValue types
	documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
@@ -164,30 +161,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
	if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
		returnData = val
	}

	// Get the collection reference
	collection := source.FirestoreClient().Collection(collectionPath)

	// Add the document to the collection
	docRef, writeResult, err := collection.Add(ctx, documentData)
	if err != nil {
		return nil, fmt.Errorf("failed to add document: %w", err)
	}

	// Build the response
	response := map[string]any{
		"documentPath": docRef.Path,
		"createTime":   writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
	}

	// Add document data if requested
	if returnData {
		// Convert the document data back to simple JSON format
		simplifiedData := util.FirestoreValueToJSON(documentData)
		response["documentData"] = simplifiedData
	}

	return response, nil
	return source.AddDocuments(ctx, collectionPath, documentData, returnData)
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
	FirestoreClient() *firestoreapi.Client
	DeleteDocuments(context.Context, []string) ([]any, error)
}

type Config struct {
@@ -104,7 +105,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
	if !ok {
		return nil, fmt.Errorf("invalid or missing '%s' parameter; expected an array", documentPathsKey)
	}

	if len(documentPathsRaw) == 0 {
		return nil, fmt.Errorf("'%s' parameter cannot be empty", documentPathsKey)
	}
@@ -126,45 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
			return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
		}
	}

	// Create a BulkWriter to handle multiple deletions efficiently
	bulkWriter := source.FirestoreClient().BulkWriter(ctx)

	// Keep track of jobs for each document
	jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths))

	// Add all delete operations to the BulkWriter
	for i, path := range documentPaths {
		docRef := source.FirestoreClient().Doc(path)
		job, err := bulkWriter.Delete(docRef)
		if err != nil {
			return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
		}
		jobs[i] = job
	}

	// End the BulkWriter to execute all operations
	bulkWriter.End()

	// Collect results
	results := make([]any, len(documentPaths))
	for i, job := range jobs {
		docData := make(map[string]any)
		docData["path"] = documentPaths[i]

		// Wait for the job to complete and get the result
		_, err := job.Results()
		if err != nil {
			docData["success"] = false
			docData["error"] = err.Error()
		} else {
			docData["success"] = true
		}

		results[i] = docData
	}

	return results, nil
	return source.DeleteDocuments(ctx, documentPaths)
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
	FirestoreClient() *firestoreapi.Client
	GetDocuments(context.Context, []string) ([]any, error)
}

type Config struct {
@@ -126,37 +127,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
			return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
		}
	}

	// Create document references from paths
	docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths))
	for i, path := range documentPaths {
		docRefs[i] = source.FirestoreClient().Doc(path)
	}

	// Get all documents
	snapshots, err := source.FirestoreClient().GetAll(ctx, docRefs)
	if err != nil {
		return nil, fmt.Errorf("failed to get documents: %w", err)
	}

	// Convert snapshots to response data
	results := make([]any, len(snapshots))
	for i, snapshot := range snapshots {
		docData := make(map[string]any)
		docData["path"] = documentPaths[i]
		docData["exists"] = snapshot.Exists()

		if snapshot.Exists() {
			docData["data"] = snapshot.Data()
			docData["createTime"] = snapshot.CreateTime
			docData["updateTime"] = snapshot.UpdateTime
			docData["readTime"] = snapshot.ReadTime
		}

		results[i] = docData
	}

	return results, nil
	return source.GetDocuments(ctx, documentPaths)
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -44,8 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
	FirebaseRulesClient() *firebaserules.Service
	GetProjectId() string
	GetDatabaseId() string
	GetRules(context.Context) (any, error)
}

type Config struct {
@@ -98,29 +97,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
	if err != nil {
		return nil, err
	}

	// Get the latest release for Firestore
	releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", source.GetProjectId(), source.GetDatabaseId())
	release, err := source.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
	}

	if release.RulesetName == "" {
		return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", source.GetProjectId(), source.GetDatabaseId())
	}

	// Get the ruleset content
	ruleset, err := source.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to get ruleset content: %w", err)
	}

	if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
		return nil, fmt.Errorf("no rules files found in ruleset")
	}

	return ruleset, nil
	return source.GetRules(ctx)
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
	FirestoreClient() *firestoreapi.Client
	ListCollections(context.Context, string) ([]any, error)
}

type Config struct {
@@ -102,47 +103,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para

	mapParams := params.AsMap()

	var collectionRefs []*firestoreapi.CollectionRef

	// Check if parentPath is provided
	parentPath, hasParent := mapParams[parentPathKey].(string)

	if hasParent && parentPath != "" {
	parentPath, _ := mapParams[parentPathKey].(string)
	if parentPath != "" {
		// Validate parent document path
		if err := util.ValidateDocumentPath(parentPath); err != nil {
			return nil, fmt.Errorf("invalid parent document path: %w", err)
		}

		// List subcollections of the specified document
		docRef := source.FirestoreClient().Doc(parentPath)
		collectionRefs, err = docRef.Collections(ctx).GetAll()
		if err != nil {
			return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
		}
	} else {
		// List root collections
		collectionRefs, err = source.FirestoreClient().Collections(ctx).GetAll()
		if err != nil {
			return nil, fmt.Errorf("failed to list root collections: %w", err)
		}
	}

	// Convert collection references to response data
	results := make([]any, len(collectionRefs))
	for i, collRef := range collectionRefs {
		collData := make(map[string]any)
		collData["id"] = collRef.ID
		collData["path"] = collRef.Path

		// If this is a subcollection, include parent information
		if collRef.Parent != nil {
			collData["parent"] = collRef.Parent.Path
		}

		results[i] = collData
	}

	return results, nil
	return source.ListCollections(ctx, parentPath)
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -36,27 +36,6 @@ const (
	defaultLimit = 100
)

// Firestore operators
var validOperators = map[string]bool{
	"<":                  true,
	"<=":                 true,
	">":                  true,
	">=":                 true,
	"==":                 true,
	"!=":                 true,
	"array-contains":     true,
	"array-contains-any": true,
	"in":                 true,
	"not-in":             true,
}

// Error messages
const (
	errFilterParseFailed    = "failed to parse filters: %w"
	errQueryExecutionFailed = "failed to execute query: %w"
	errLimitParseFailed     = "failed to parse limit value '%s': %w"
)

func init() {
	if !tools.Register(kind, newConfig) {
		panic(fmt.Sprintf("tool kind %q already registered", kind))
@@ -74,6 +53,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
// compatibleSource defines the interface for sources that can provide a Firestore client
type compatibleSource interface {
	FirestoreClient() *firestoreapi.Client
	BuildQuery(string, firestoreapi.EntityFilter, []string, string, firestoreapi.Direction, int, bool) (*firestoreapi.Query, error)
	ExecuteQuery(context.Context, *firestoreapi.Query, bool) (any, error)
}

// Config represents the configuration for the Firestore query tool
@@ -139,15 +120,6 @@ func (t Tool) ToConfig() tools.ToolConfig {
	return t.Config
}

// SimplifiedFilter represents the simplified filter format
type SimplifiedFilter struct {
	And   []SimplifiedFilter `json:"and,omitempty"`
	Or    []SimplifiedFilter `json:"or,omitempty"`
	Field string             `json:"field,omitempty"`
	Op    string             `json:"op,omitempty"`
	Value interface{}        `json:"value,omitempty"`
}

// OrderByConfig represents ordering configuration
type OrderByConfig struct {
	Field string `json:"field"`
@@ -162,20 +134,27 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
	return firestoreapi.Asc
}

// QueryResult represents a document result from the query
type QueryResult struct {
	ID         string         `json:"id"`
	Path       string         `json:"path"`
	Data       map[string]any `json:"data"`
	CreateTime interface{}    `json:"createTime,omitempty"`
	UpdateTime interface{}    `json:"updateTime,omitempty"`
	ReadTime   interface{}    `json:"readTime,omitempty"`
// SimplifiedFilter represents the simplified filter format
type SimplifiedFilter struct {
	And   []SimplifiedFilter `json:"and,omitempty"`
	Or    []SimplifiedFilter `json:"or,omitempty"`
	Field string             `json:"field,omitempty"`
	Op    string             `json:"op,omitempty"`
	Value interface{}        `json:"value,omitempty"`
}

// QueryResponse represents the full response including optional metrics
type QueryResponse struct {
	Documents      []QueryResult  `json:"documents"`
	ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
// Firestore operators
var validOperators = map[string]bool{
	"<":                  true,
	"<=":                 true,
	">":                  true,
	">=":                 true,
	"==":                 true,
	"!=":                 true,
	"array-contains":     true,
	"array-contains-any": true,
	"in":                 true,
	"not-in":             true,
}
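The SimplifiedFilter struct above implies a nested JSON shape for the filters template; an illustrative (made-up) filter combining both combinators:

	{"and": [
	  {"field": "age", "op": ">=", "value": 21},
	  {"or": [
	    {"field": "city", "op": "==", "value": "NYC"},
	    {"field": "city", "op": "==", "value": "SF"}
	  ]}
	]}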
// Invoke executes the Firestore query based on the provided parameters
@@ -184,34 +163,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
	if err != nil {
		return nil, err
	}

	paramsMap := params.AsMap()

	// Process collection path with template substitution
	collectionPath, err := parameters.PopulateTemplate("collectionPath", t.CollectionPath, paramsMap)
	if err != nil {
		return nil, fmt.Errorf("failed to process collection path: %w", err)
	}

	// Build the query
	query, err := t.buildQuery(source, collectionPath, paramsMap)
	if err != nil {
		return nil, err
	}

	// Execute the query and return results
	return t.executeQuery(ctx, query)
}

// buildQuery constructs the Firestore query from parameters
func (t Tool) buildQuery(source compatibleSource, collectionPath string, params map[string]any) (*firestoreapi.Query, error) {
	collection := source.FirestoreClient().Collection(collectionPath)
	query := collection.Query

	var filter firestoreapi.EntityFilter
	// Process and apply filters if template is provided
	if t.Filters != "" {
		// Apply template substitution to filters
		filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, params)
		filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, paramsMap)
		if err != nil {
			return nil, fmt.Errorf("failed to process filters template: %w", err)
		}
@@ -219,48 +182,43 @@ func (t Tool) buildQuery(source compatibleSource, collectionPath string, params
		// Parse the simplified filter format
		var simplifiedFilter SimplifiedFilter
		if err := json.Unmarshal([]byte(filtersJSON), &simplifiedFilter); err != nil {
			return nil, fmt.Errorf(errFilterParseFailed, err)
			return nil, fmt.Errorf("failed to parse filters: %w", err)
		}

		// Convert simplified filter to Firestore filter
		if filter := t.convertToFirestoreFilter(source, simplifiedFilter); filter != nil {
			query = query.WhereEntity(filter)
		}
		filter = t.convertToFirestoreFilter(source, simplifiedFilter)
	}

	// Process select fields
	selectFields, err := t.processSelectFields(params)
	if err != nil {
		return nil, err
	}
	if len(selectFields) > 0 {
		query = query.Select(selectFields...)
	}

	// Process and apply ordering
	orderBy, err := t.getOrderBy(params)
	orderBy, err := t.getOrderBy(paramsMap)
	if err != nil {
		return nil, err
	}
	if orderBy != nil {
		query = query.OrderBy(orderBy.Field, orderBy.GetDirection())
	// Process select fields
	selectFields, err := t.processSelectFields(paramsMap)
	if err != nil {
		return nil, err
	}

	// Process and apply limit
	limit, err := t.getLimit(params)
	limit, err := t.getLimit(paramsMap)
	if err != nil {
		return nil, err
	}
	query = query.Limit(limit)

	// Apply analyze options if enabled
	if t.AnalyzeQuery {
		query = query.WithRunOptions(firestoreapi.ExplainOptions{
			Analyze: true,
		})
	// prevent panic when accessing orderBy in case it is nil
	var orderByField string
	var orderByDirection firestoreapi.Direction
	if orderBy != nil {
		orderByField = orderBy.Field
		orderByDirection = orderBy.GetDirection()
	}

	return &query, nil
	// Build the query
	query, err := source.BuildQuery(collectionPath, filter, selectFields, orderByField, orderByDirection, limit, t.AnalyzeQuery)
	if err != nil {
		return nil, err
	}
	// Execute the query and return results
	return source.ExecuteQuery(ctx, query, t.AnalyzeQuery)
}

// convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter
@@ -409,7 +367,7 @@ func (t Tool) getLimit(params map[string]any) (int, error) {
	if processedValue != "" {
		parsedLimit, err := strconv.Atoi(processedValue)
		if err != nil {
			return 0, fmt.Errorf(errLimitParseFailed, processedValue, err)
			return 0, fmt.Errorf("failed to parse limit value '%s': %w", processedValue, err)
		}
		limit = parsedLimit
	}
@@ -417,78 +375,6 @@ func (t Tool) getLimit(params map[string]any) (int, error) {
	return limit, nil
}

// executeQuery runs the query and formats the results
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query) (any, error) {
	docIterator := query.Documents(ctx)
	docs, err := docIterator.GetAll()
	if err != nil {
		return nil, fmt.Errorf(errQueryExecutionFailed, err)
	}

	// Convert results to structured format
	results := make([]QueryResult, len(docs))
	for i, doc := range docs {
		results[i] = QueryResult{
			ID:         doc.Ref.ID,
			Path:       doc.Ref.Path,
			Data:       doc.Data(),
			CreateTime: doc.CreateTime,
			UpdateTime: doc.UpdateTime,
			ReadTime:   doc.ReadTime,
		}
	}

	// Return with explain metrics if requested
	if t.AnalyzeQuery {
		explainMetrics, err := t.getExplainMetrics(docIterator)
		if err == nil && explainMetrics != nil {
			response := QueryResponse{
				Documents:      results,
				ExplainMetrics: explainMetrics,
			}
			return response, nil
		}
	}

	return results, nil
}

// getExplainMetrics extracts explain metrics from the query iterator
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
	explainMetrics, err := docIterator.ExplainMetrics()
	if err != nil || explainMetrics == nil {
		return nil, err
	}

	metricsData := make(map[string]any)

	// Add plan summary if available
	if explainMetrics.PlanSummary != nil {
		planSummary := make(map[string]any)
		planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
		metricsData["planSummary"] = planSummary
	}

	// Add execution stats if available
	if explainMetrics.ExecutionStats != nil {
		executionStats := make(map[string]any)
		executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
		executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations

		if explainMetrics.ExecutionStats.ExecutionDuration != nil {
			executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
		}

		if explainMetrics.ExecutionStats.DebugStats != nil {
			executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
		}

		metricsData["executionStats"] = executionStats
	}

	return metricsData, nil
}

// ParseParams parses and validates input parameters
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
	return parameters.ParseParams(t.Parameters, data, claims)

@@ -69,7 +69,6 @@ const (
	errInvalidOperator = "unsupported operator: %s. Valid operators are: %v"
|
||||
errMissingFilterValue = "no value specified for filter on field '%s'"
|
||||
errOrderByParseFailed = "failed to parse orderBy: %w"
|
||||
errQueryExecutionFailed = "failed to execute query: %w"
|
||||
errTooManyFilters = "too many filters provided: %d (maximum: %d)"
|
||||
)
|
||||
|
||||
@@ -90,6 +89,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
// compatibleSource defines the interface for sources that can provide a Firestore client
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
BuildQuery(string, firestoreapi.EntityFilter, []string, string, firestoreapi.Direction, int, bool) (*firestoreapi.Query, error)
|
||||
ExecuteQuery(context.Context, *firestoreapi.Query, bool) (any, error)
|
||||
}
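
For orientation, a source could satisfy the new BuildQuery method roughly like this — a minimal sketch, assuming the source struct exposes its Firestore client as s.Client (receiver and field names are illustrative; the body mirrors the query-building logic removed from the tool above):

func (s *Source) BuildQuery(collectionPath string, filter firestoreapi.EntityFilter, selectFields []string, orderByField string, orderByDirection firestoreapi.Direction, limit int, analyze bool) (*firestoreapi.Query, error) {
    // Start from the collection and layer on each optional clause.
    query := s.Client.Collection(collectionPath).Query
    if filter != nil {
        query = query.WhereEntity(filter)
    }
    if len(selectFields) > 0 {
        query = query.Select(selectFields...)
    }
    if orderByField != "" {
        query = query.OrderBy(orderByField, orderByDirection)
    }
    query = query.Limit(limit)
    if analyze {
        // Request explain metrics alongside the documents.
        query = query.WithRunOptions(firestoreapi.ExplainOptions{Analyze: true})
    }
    return &query, nil
}
// ExecuteQuery would mirror the tool's executeQuery shown further down,
// returning explain metrics when requested.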

// Config represents the configuration for the Firestore query collection tool
@@ -228,22 +229,6 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
return firestoreapi.Asc
}

// QueryResult represents a document result from the query
type QueryResult struct {
ID string `json:"id"`
Path string `json:"path"`
Data map[string]any `json:"data"`
CreateTime interface{} `json:"createTime,omitempty"`
UpdateTime interface{} `json:"updateTime,omitempty"`
ReadTime interface{} `json:"readTime,omitempty"`
}

// QueryResponse represents the full response including optional metrics
type QueryResponse struct {
Documents []QueryResult `json:"documents"`
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
}

// Invoke executes the Firestore query based on the provided parameters
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
@@ -257,14 +242,37 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, err
}

var filter firestoreapi.EntityFilter
// Apply filters
if len(queryParams.Filters) > 0 {
filterConditions := make([]firestoreapi.EntityFilter, 0, len(queryParams.Filters))
for _, filter := range queryParams.Filters {
filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
Path: filter.Field,
Operator: filter.Op,
Value: filter.Value,
})
}

filter = firestoreapi.AndFilter{
Filters: filterConditions,
}
}

// prevent panic in case queryParams.OrderBy is nil
var orderByField string
var orderByDirection firestoreapi.Direction
if queryParams.OrderBy != nil {
orderByField = queryParams.OrderBy.Field
orderByDirection = queryParams.OrderBy.GetDirection()
}

// Build the query
query, err := t.buildQuery(source, queryParams)
query, err := source.BuildQuery(queryParams.CollectionPath, filter, nil, orderByField, orderByDirection, queryParams.Limit, queryParams.AnalyzeQuery)
if err != nil {
return nil, err
}

// Execute the query and return results
return t.executeQuery(ctx, query, queryParams.AnalyzeQuery)
return source.ExecuteQuery(ctx, query, queryParams.AnalyzeQuery)
}

// queryParameters holds all parsed query parameters
@@ -380,122 +388,6 @@ func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) {
return &orderBy, nil
}

// buildQuery constructs the Firestore query from parameters
func (t Tool) buildQuery(source compatibleSource, params *queryParameters) (*firestoreapi.Query, error) {
collection := source.FirestoreClient().Collection(params.CollectionPath)
query := collection.Query

// Apply filters
if len(params.Filters) > 0 {
filterConditions := make([]firestoreapi.EntityFilter, 0, len(params.Filters))
for _, filter := range params.Filters {
filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
Path: filter.Field,
Operator: filter.Op,
Value: filter.Value,
})
}

query = query.WhereEntity(firestoreapi.AndFilter{
Filters: filterConditions,
})
}

// Apply ordering
if params.OrderBy != nil {
query = query.OrderBy(params.OrderBy.Field, params.OrderBy.GetDirection())
}

// Apply limit
query = query.Limit(params.Limit)

// Apply analyze options
if params.AnalyzeQuery {
query = query.WithRunOptions(firestoreapi.ExplainOptions{
Analyze: true,
})
}

return &query, nil
}

// executeQuery runs the query and formats the results
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query, analyzeQuery bool) (any, error) {
docIterator := query.Documents(ctx)
docs, err := docIterator.GetAll()
if err != nil {
return nil, fmt.Errorf(errQueryExecutionFailed, err)
}

// Convert results to structured format
results := make([]QueryResult, len(docs))
for i, doc := range docs {
results[i] = QueryResult{
ID: doc.Ref.ID,
Path: doc.Ref.Path,
Data: doc.Data(),
CreateTime: doc.CreateTime,
UpdateTime: doc.UpdateTime,
ReadTime: doc.ReadTime,
}
}

// Return with explain metrics if requested
if analyzeQuery {
explainMetrics, err := t.getExplainMetrics(docIterator)
if err == nil && explainMetrics != nil {
response := QueryResponse{
Documents: results,
ExplainMetrics: explainMetrics,
}
return response, nil
}
}

// Return just the documents
resultsAny := make([]any, len(results))
for i, r := range results {
resultsAny[i] = r
}
return resultsAny, nil
}

// getExplainMetrics extracts explain metrics from the query iterator
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
explainMetrics, err := docIterator.ExplainMetrics()
if err != nil || explainMetrics == nil {
return nil, err
}

metricsData := make(map[string]any)

// Add plan summary if available
if explainMetrics.PlanSummary != nil {
planSummary := make(map[string]any)
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
metricsData["planSummary"] = planSummary
}

// Add execution stats if available
if explainMetrics.ExecutionStats != nil {
executionStats := make(map[string]any)
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations

if explainMetrics.ExecutionStats.ExecutionDuration != nil {
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
}

if explainMetrics.ExecutionStats.DebugStats != nil {
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
}

metricsData["executionStats"] = executionStats
}

return metricsData, nil
}

// ParseParams parses and validates input parameters
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.Parameters, data, claims)

@@ -50,6 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
UpdateDocument(context.Context, string, []firestoreapi.Update, any, bool) (map[string]any, error)
}

type Config struct {
@@ -177,23 +178,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
}

// Get return document data flag
returnData := false
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
returnData = val
}

// Get the document reference
docRef := source.FirestoreClient().Doc(documentPath)

// Prepare update data
var writeResult *firestoreapi.WriteResult
var writeErr error

// Use selective field update with update mask
updates := make([]firestoreapi.Update, 0, len(updatePaths))
var documentData any
if len(updatePaths) > 0 {
// Use selective field update with update mask
updates := make([]firestoreapi.Update, 0, len(updatePaths))

// Convert document data without delete markers
dataMap, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
@@ -220,41 +208,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
Value: value,
})
}

writeResult, writeErr = docRef.Update(ctx, updates)
} else {
// Update all fields in the document data (merge)
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
documentData, err = util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
if err != nil {
return nil, fmt.Errorf("failed to convert document data: %w", err)
}
writeResult, writeErr = docRef.Set(ctx, documentData, firestoreapi.MergeAll)
}

if writeErr != nil {
return nil, fmt.Errorf("failed to update document: %w", writeErr)
// Get return document data flag
returnData := false
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
returnData = val
}

// Build the response
response := map[string]any{
"documentPath": docRef.Path,
"updateTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
}

// Add document data if requested
if returnData {
// Fetch the updated document to return the current state
snapshot, err := docRef.Get(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
}

// Convert the document data to simple JSON format
simplifiedData := util.FirestoreValueToJSON(snapshot.Data())
response["documentData"] = simplifiedData
}

return response, nil
return source.UpdateDocument(ctx, documentPath, updates, documentData, returnData)
}
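
The removed write logic now lives behind UpdateDocument. A minimal source-side sketch, assuming the source holds the Firestore client as s.Client (names illustrative; the body mirrors the code deleted above):

func (s *Source) UpdateDocument(ctx context.Context, documentPath string, updates []firestoreapi.Update, documentData any, returnData bool) (map[string]any, error) {
    docRef := s.Client.Doc(documentPath)
    var writeResult *firestoreapi.WriteResult
    var err error
    if len(updates) > 0 {
        // Selective field update with an update mask.
        writeResult, err = docRef.Update(ctx, updates)
    } else {
        // Merge all fields from the converted document data.
        writeResult, err = docRef.Set(ctx, documentData, firestoreapi.MergeAll)
    }
    if err != nil {
        return nil, fmt.Errorf("failed to update document: %w", err)
    }
    response := map[string]any{
        "documentPath": docRef.Path,
        "updateTime":   writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
    }
    if returnData {
        // Fetch the updated document to return its current state.
        snapshot, err := docRef.Get(ctx)
        if err != nil {
            return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
        }
        response["documentData"] = util.FirestoreValueToJSON(snapshot.Data())
    }
    return response, nil
}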

// getFieldValue retrieves a value from a nested map using a dot-separated path

@@ -17,7 +17,6 @@ package firestorevalidaterules
import (
"context"
"fmt"
"strings"

yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
@@ -50,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
FirebaseRulesClient() *firebaserules.Service
GetProjectId() string
ValidateRules(context.Context, string) (any, error)
}

type Config struct {
@@ -107,30 +106,6 @@ func (t Tool) ToConfig() tools.ToolConfig {
return t.Config
}

// Issue represents a validation issue in the rules
type Issue struct {
SourcePosition SourcePosition `json:"sourcePosition"`
Description string `json:"description"`
Severity string `json:"severity"`
}

// SourcePosition represents the location of an issue in the source
type SourcePosition struct {
FileName string `json:"fileName,omitempty"`
Line int64 `json:"line"` // 1-based
Column int64 `json:"column"` // 1-based
CurrentOffset int64 `json:"currentOffset"` // 0-based, inclusive start
EndOffset int64 `json:"endOffset"` // 0-based, exclusive end
}

// ValidationResult represents the result of rules validation
type ValidationResult struct {
Valid bool `json:"valid"`
IssueCount int `json:"issueCount"`
FormattedIssues string `json:"formattedIssues,omitempty"`
RawIssues []Issue `json:"rawIssues,omitempty"`
}

func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
@@ -144,114 +119,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok || sourceParam == "" {
return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey)
}

// Create test request
testRequest := &firebaserules.TestRulesetRequest{
Source: &firebaserules.Source{
Files: []*firebaserules.File{
{
Name: "firestore.rules",
Content: sourceParam,
},
},
},
// We don't need test cases for validation only
TestSuite: &firebaserules.TestSuite{
TestCases: []*firebaserules.TestCase{},
},
}

// Call the test API
projectName := fmt.Sprintf("projects/%s", source.GetProjectId())
response, err := source.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to validate rules: %w", err)
}

// Process the response
result := t.processValidationResponse(response, sourceParam)

return result, nil
}

func (t Tool) processValidationResponse(response *firebaserules.TestRulesetResponse, source string) ValidationResult {
if len(response.Issues) == 0 {
return ValidationResult{
Valid: true,
IssueCount: 0,
FormattedIssues: "✓ No errors detected. Rules are valid.",
}
}

// Convert issues to our format
issues := make([]Issue, len(response.Issues))
for i, issue := range response.Issues {
issues[i] = Issue{
Description: issue.Description,
Severity: issue.Severity,
SourcePosition: SourcePosition{
FileName: issue.SourcePosition.FileName,
Line: issue.SourcePosition.Line,
Column: issue.SourcePosition.Column,
CurrentOffset: issue.SourcePosition.CurrentOffset,
EndOffset: issue.SourcePosition.EndOffset,
},
}
}

// Format issues
formattedIssues := t.formatRulesetIssues(issues, source)

return ValidationResult{
Valid: false,
IssueCount: len(issues),
FormattedIssues: formattedIssues,
RawIssues: issues,
}
}

// formatRulesetIssues formats validation issues into a human-readable string with code snippets
func (t Tool) formatRulesetIssues(issues []Issue, rulesSource string) string {
sourceLines := strings.Split(rulesSource, "\n")
var formattedOutput []string

formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))

for _, issue := range issues {
issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
issue.Severity,
issue.Description,
issue.SourcePosition.Line,
issue.SourcePosition.Column)

if issue.SourcePosition.Line > 0 {
lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
if lineIndex >= 0 && lineIndex < len(sourceLines) {
errorLine := sourceLines[lineIndex]
issueString += fmt.Sprintf("\n```\n%s", errorLine)

// Add carets if we have column and offset information
if issue.SourcePosition.Column > 0 &&
issue.SourcePosition.CurrentOffset >= 0 &&
issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {

startColumn := int(issue.SourcePosition.Column - 1) // 0-based
errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)

if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
padding := strings.Repeat(" ", startColumn)
carets := strings.Repeat("^", errorTokenLength)
issueString += fmt.Sprintf("\n%s%s", padding, carets)
}
}
issueString += "\n```"
}
}

formattedOutput = append(formattedOutput, issueString)
}

return strings.Join(formattedOutput, "\n\n")
return source.ValidateRules(ctx, sourceParam)
}
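
The deleted request/response handling would live behind ValidateRules on the source. A sketch under that assumption (the issue conversion and formatting removed above would move along with it):

func (s *Source) ValidateRules(ctx context.Context, rulesSource string) (any, error) {
    testRequest := &firebaserules.TestRulesetRequest{
        Source: &firebaserules.Source{
            Files: []*firebaserules.File{{Name: "firestore.rules", Content: rulesSource}},
        },
        // Validation only: no test cases are needed.
        TestSuite: &firebaserules.TestSuite{TestCases: []*firebaserules.TestCase{}},
    }
    projectName := fmt.Sprintf("projects/%s", s.GetProjectId())
    response, err := s.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
    if err != nil {
        return nil, fmt.Errorf("failed to validate rules: %w", err)
    }
    // Issue conversion and formatting (processValidationResponse and
    // formatRulesetIssues above) would move here as well.
    return response, nil
}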

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

@@ -28,13 +28,13 @@ import (
// JSONToFirestoreValue converts a JSON value with type information to a Firestore-compatible value
// The input should be a map with a single key indicating the type (e.g., "stringValue", "integerValue")
// If a client is provided, referenceValue types will be converted to *firestore.DocumentRef
func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interface{}, error) {
func JSONToFirestoreValue(value any, client *firestore.Client) (any, error) {
if value == nil {
return nil, nil
}

switch v := value.(type) {
case map[string]interface{}:
case map[string]any:
// Check for typed values
if len(v) == 1 {
for key, val := range v {
@@ -92,7 +92,7 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
return nil, fmt.Errorf("timestamp value must be a string")
case "geoPointValue":
// Convert to LatLng
if geoMap, ok := val.(map[string]interface{}); ok {
if geoMap, ok := val.(map[string]any); ok {
lat, latOk := geoMap["latitude"].(float64)
lng, lngOk := geoMap["longitude"].(float64)
if latOk && lngOk {
@@ -105,9 +105,9 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
return nil, fmt.Errorf("invalid geopoint value format")
case "arrayValue":
// Convert array
if arrayMap, ok := val.(map[string]interface{}); ok {
if values, ok := arrayMap["values"].([]interface{}); ok {
result := make([]interface{}, len(values))
if arrayMap, ok := val.(map[string]any); ok {
if values, ok := arrayMap["values"].([]any); ok {
result := make([]any, len(values))
for i, item := range values {
converted, err := JSONToFirestoreValue(item, client)
if err != nil {
@@ -121,9 +121,9 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
return nil, fmt.Errorf("invalid array value format")
case "mapValue":
// Convert map
if mapMap, ok := val.(map[string]interface{}); ok {
if fields, ok := mapMap["fields"].(map[string]interface{}); ok {
result := make(map[string]interface{})
if mapMap, ok := val.(map[string]any); ok {
if fields, ok := mapMap["fields"].(map[string]any); ok {
result := make(map[string]any)
for k, v := range fields {
converted, err := JSONToFirestoreValue(v, client)
if err != nil {
@@ -160,8 +160,8 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
}

// convertPlainMap converts a plain map to Firestore format
func convertPlainMap(m map[string]interface{}, client *firestore.Client) (map[string]interface{}, error) {
result := make(map[string]interface{})
func convertPlainMap(m map[string]any, client *firestore.Client) (map[string]any, error) {
result := make(map[string]any)
for k, v := range m {
converted, err := JSONToFirestoreValue(v, client)
if err != nil {
@@ -172,42 +172,6 @@ func convertPlainMap(m map[string]interface{}, client *firestore.Client) (map[st
return result, nil
}

// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
// This removes type information and returns plain values
func FirestoreValueToJSON(value interface{}) interface{} {
if value == nil {
return nil
}

switch v := value.(type) {
case time.Time:
return v.Format(time.RFC3339Nano)
case *latlng.LatLng:
return map[string]interface{}{
"latitude": v.Latitude,
"longitude": v.Longitude,
}
case []byte:
return base64.StdEncoding.EncodeToString(v)
case []interface{}:
result := make([]interface{}, len(v))
for i, item := range v {
result[i] = FirestoreValueToJSON(item)
}
return result
case map[string]interface{}:
result := make(map[string]interface{})
for k, val := range v {
result[k] = FirestoreValueToJSON(val)
}
return result
case *firestore.DocumentRef:
return v.Path
default:
return value
}
}

// isValidDocumentPath checks if a string is a valid Firestore document path
// Valid paths have an even number of segments (collection/doc/collection/doc...)
func isValidDocumentPath(path string) bool {

@@ -312,40 +312,6 @@ func TestJSONToFirestoreValue_IntegerFromString(t *testing.T) {
}
}

func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
// Test round-trip conversion
original := map[string]interface{}{
"name": "Test",
"count": int64(42),
"price": 19.99,
"active": true,
"tags": []interface{}{"tag1", "tag2"},
"metadata": map[string]interface{}{
"created": time.Now(),
},
"nullField": nil,
}

// Convert to JSON representation
jsonRepresentation := FirestoreValueToJSON(original)

// Verify types are simplified
jsonMap, ok := jsonRepresentation.(map[string]interface{})
if !ok {
t.Fatalf("Expected map, got %T", jsonRepresentation)
}

// Time should be converted to string
metadata, ok := jsonMap["metadata"].(map[string]interface{})
if !ok {
t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
}
_, ok = metadata["created"].(string)
if !ok {
t.Errorf("created should be a string, got %T", metadata["created"])
}
}

func TestJSONToFirestoreValue_InvalidFormats(t *testing.T) {
tests := []struct {
name string

@@ -16,7 +16,9 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"slices"
@@ -52,7 +54,7 @@ type compatibleSource interface {
HttpDefaultHeaders() map[string]string
HttpBaseURL() string
HttpQueryParams() map[string]string
RunRequest(*http.Request) (any, error)
Client() *http.Client
}

type Config struct {
@@ -257,7 +259,29 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
for k, v := range allHeaders {
req.Header.Set(k, v)
}
return source.RunRequest(req)

// Make request and fetch response
resp, err := source.Client().Do(req)
if err != nil {
return nil, fmt.Errorf("error making HTTP request: %s", err)
}
defer resp.Body.Close()

var body []byte
body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
}

var data any
if err = json.Unmarshal(body, &data); err != nil {
// if unable to unmarshal data, return result as string.
return string(body), nil
}
return data, nil
}
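
A source can implement RunRequest by lifting the deleted block nearly verbatim — a sketch, assuming the source keeps its *http.Client as s.client:

func (s *Source) RunRequest(req *http.Request) (any, error) {
    resp, err := s.client.Do(req)
    if err != nil {
        return nil, fmt.Errorf("error making HTTP request: %s", err)
    }
    defer resp.Body.Close()
    body, err := io.ReadAll(resp.Body)
    if err != nil {
        return nil, err
    }
    if resp.StatusCode < 200 || resp.StatusCode > 299 {
        return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
    }
    var data any
    if err := json.Unmarshal(body, &data); err != nil {
        // If the body is not JSON, return it as a string.
        return string(body), nil
    }
    return data, nil
}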

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

@@ -1,10 +1,10 @@
// Copyright 2026 Google LLC
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package serverlessspark
package common

import (
"fmt"
@@ -23,13 +23,13 @@ import (
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
)

var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)

const (
logTimeBufferBefore = 1 * time.Minute
logTimeBufferAfter = 10 * time.Minute
)

var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)

// ExtractBatchDetails extracts the project ID, location, and batch ID from a fully qualified batch name.
func ExtractBatchDetails(batchName string) (projectID, location, batchID string, err error) {
matches := batchFullNameRegex.FindStringSubmatch(batchName)
@@ -39,6 +39,26 @@ func ExtractBatchDetails(batchName string) (projectID, location, batchID string,
return matches[1], matches[2], matches[3], nil
}
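
For example (values taken from the tests further down), the helper splits a fully qualified name into its three components:

projectID, location, batchID, err := ExtractBatchDetails(
    "projects/my-project/locations/us-central1/batches/my-batch")
// projectID == "my-project", location == "us-central1", batchID == "my-batch", err == nil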

// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
if err != nil {
return "", err
}
return BatchConsoleURL(projectID, location, batchID), nil
}

// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
if err != nil {
return "", err
}
createTime := batchPb.GetCreateTime().AsTime()
stateTime := batchPb.GetStateTime().AsTime()
return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
}

// BatchConsoleURL builds a URL to the Google Cloud Console linking to the batch summary page.
func BatchConsoleURL(projectID, location, batchID string) string {
return fmt.Sprintf("https://console.cloud.google.com/dataproc/batches/%s/%s/summary?project=%s", location, batchID, projectID)
@@ -69,23 +89,3 @@ resource.labels.batch_id="%s"`

return "https://console.cloud.google.com/logs/viewer?" + v.Encode()
}

// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
if err != nil {
return "", err
}
return BatchConsoleURL(projectID, location, batchID), nil
}

// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
if err != nil {
return "", err
}
createTime := batchPb.GetCreateTime().AsTime()
stateTime := batchPb.GetStateTime().AsTime()
return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
}
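
Usage sketch; the expected string matches the tests that follow:

consoleURL := BatchConsoleURL("my-project", "us-central1", "my-batch")
// "https://console.cloud.google.com/dataproc/batches/us-central1/my-batch/summary?project=my-project"
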
@@ -1,10 +1,10 @@
// Copyright 2026 Google LLC
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package serverlessspark_test
package common

import (
"testing"
"time"

"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
"google.golang.org/protobuf/types/known/timestamppb"
)

func TestExtractBatchDetails_Success(t *testing.T) {
batchName := "projects/my-project/locations/us-central1/batches/my-batch"
projectID, location, batchID, err := serverlessspark.ExtractBatchDetails(batchName)
projectID, location, batchID, err := ExtractBatchDetails(batchName)
if err != nil {
t.Errorf("ExtractBatchDetails() error = %v, want no error", err)
return
@@ -46,7 +45,7 @@ func TestExtractBatchDetails_Success(t *testing.T) {

func TestExtractBatchDetails_Failure(t *testing.T) {
batchName := "invalid-name"
_, _, _, err := serverlessspark.ExtractBatchDetails(batchName)
_, _, _, err := ExtractBatchDetails(batchName)
wantErr := "failed to parse batch name: invalid-name"
if err == nil || err.Error() != wantErr {
t.Errorf("ExtractBatchDetails() error = %v, want %v", err, wantErr)
@@ -54,7 +53,7 @@ func TestExtractBatchDetails_Failure(t *testing.T) {
}

func TestBatchConsoleURL(t *testing.T) {
got := serverlessspark.BatchConsoleURL("my-project", "us-central1", "my-batch")
got := BatchConsoleURL("my-project", "us-central1", "my-batch")
want := "https://console.cloud.google.com/dataproc/batches/us-central1/my-batch/summary?project=my-project"
if got != want {
t.Errorf("BatchConsoleURL() = %v, want %v", got, want)
@@ -64,7 +63,7 @@ func TestBatchConsoleURL(t *testing.T) {
func TestBatchLogsURL(t *testing.T) {
startTime := time.Date(2025, 10, 1, 5, 0, 0, 0, time.UTC)
endTime := time.Date(2025, 10, 1, 6, 0, 0, 0, time.UTC)
got := serverlessspark.BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
got := BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
want := "https://console.cloud.google.com/logs/viewer?advancedFilter=" +
"resource.type%3D%22cloud_dataproc_batch%22" +
"%0Aresource.labels.project_id%3D%22my-project%22" +
@@ -83,7 +82,7 @@ func TestBatchConsoleURLFromProto(t *testing.T) {
batchPb := &dataprocpb.Batch{
Name: "projects/my-project/locations/us-central1/batches/my-batch",
}
got, err := serverlessspark.BatchConsoleURLFromProto(batchPb)
got, err := BatchConsoleURLFromProto(batchPb)
if err != nil {
t.Fatalf("BatchConsoleURLFromProto() error = %v", err)
}
@@ -101,7 +100,7 @@ func TestBatchLogsURLFromProto(t *testing.T) {
CreateTime: timestamppb.New(createTime),
StateTime: timestamppb.New(stateTime),
}
got, err := serverlessspark.BatchLogsURLFromProto(batchPb)
got, err := BatchLogsURLFromProto(batchPb)
if err != nil {
t.Fatalf("BatchLogsURLFromProto() error = %v", err)
}
@@ -19,6 +19,7 @@ import (
"encoding/json"
"fmt"

dataproc "cloud.google.com/go/dataproc/v2/apiv1"
dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/goccy/go-yaml"
"google.golang.org/protobuf/encoding/protojson"
@@ -35,7 +36,9 @@ func unmarshalProto(data any, m proto.Message) error {
}

type compatibleSource interface {
CreateBatch(context.Context, *dataprocpb.Batch) (map[string]any, error)
GetBatchControllerClient() *dataproc.BatchControllerClient
GetProject() string
GetLocation() string
}

// Config is a common config that can be used with any type of create batch tool. However, each tool

@@ -16,19 +16,23 @@ package createbatch

import (
"context"
"encoding/json"
"fmt"
"time"

dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)

type BatchBuilder interface {
Parameters() parameters.Parameters
BuildBatch(parameters.ParamValues) (*dataprocpb.Batch, error)
BuildBatch(params parameters.ParamValues) (*dataprocpb.Batch, error)
}

func NewTool(cfg Config, originalCfg tools.ToolConfig, srcs map[string]sources.Source, builder BatchBuilder) (*Tool, error) {
@@ -70,6 +74,7 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par
if err != nil {
return nil, err
}
client := source.GetBatchControllerClient()

batch, err := t.Builder.BuildBatch(params)
if err != nil {
@@ -92,7 +97,46 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par
}
batch.RuntimeConfig.Version = version
}
return source.CreateBatch(ctx, batch)

req := &dataprocpb.CreateBatchRequest{
Parent: fmt.Sprintf("projects/%s/locations/%s", source.GetProject(), source.GetLocation()),
Batch: batch,
}

op, err := client.CreateBatch(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to create batch: %w", err)
}

meta, err := op.Metadata()
if err != nil {
return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
}

jsonBytes, err := protojson.Marshal(meta)
if err != nil {
return nil, fmt.Errorf("failed to marshal create batch op metadata to JSON: %w", err)
}

var result map[string]any
if err := json.Unmarshal(jsonBytes, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal create batch op metadata JSON: %w", err)
}

projectID, location, batchID, err := common.ExtractBatchDetails(meta.GetBatch())
if err != nil {
return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
}
consoleUrl := common.BatchConsoleURL(projectID, location, batchID)
logsUrl := common.BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})

wrappedResult := map[string]any{
"opMetadata": meta,
"consoleUrl": consoleUrl,
"logsUrl": logsUrl,
}

return wrappedResult, nil
}

func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

@@ -19,7 +19,8 @@ import (
"fmt"
"strings"

dataproc "cloud.google.com/go/dataproc/v2/apiv1"
longrunning "cloud.google.com/go/longrunning/autogen"
"cloud.google.com/go/longrunning/autogen/longrunningpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -44,8 +45,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}

type compatibleSource interface {
GetBatchControllerClient() *dataproc.BatchControllerClient
CancelOperation(context.Context, string) (any, error)
GetOperationsClient(context.Context) (*longrunning.OperationsClient, error)
GetProject() string
GetLocation() string
}

type Config struct {
@@ -104,15 +106,32 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par
if err != nil {
return nil, err
}

client, err := source.GetOperationsClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get operations client: %w", err)
}

paramMap := params.AsMap()
operation, ok := paramMap["operation"].(string)
if !ok {
return nil, fmt.Errorf("missing required parameter: operation")
}

if strings.Contains(operation, "/") {
return nil, fmt.Errorf("operation must be a short operation name without '/': %s", operation)
}
return source.CancelOperation(ctx, operation)

req := &longrunningpb.CancelOperationRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", source.GetProject(), source.GetLocation(), operation),
}

err = client.CancelOperation(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to cancel operation: %w", err)
}

return fmt.Sprintf("Cancelled [%s].", operation), nil
}
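
Source-side, CancelOperation can wrap the deleted request construction — a sketch, assuming the source keeps a long-running operations client (s.opsClient, s.Project, and s.Location are illustrative names):

func (s *Source) CancelOperation(ctx context.Context, operation string) (any, error) {
    // s.opsClient is assumed to be a *longrunning.OperationsClient held by the source.
    req := &longrunningpb.CancelOperationRequest{
        Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", s.Project, s.Location, operation),
    }
    if err := s.opsClient.CancelOperation(ctx, req); err != nil {
        return nil, fmt.Errorf("failed to cancel operation: %w", err)
    }
    return fmt.Sprintf("Cancelled [%s].", operation), nil
}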

func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

@@ -16,15 +16,19 @@ package serverlesssparkgetbatch

import (
"context"
"encoding/json"
"fmt"
"strings"

dataproc "cloud.google.com/go/dataproc/v2/apiv1"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/protobuf/encoding/protojson"
)

const kind = "serverless-spark-get-batch"
@@ -45,7 +49,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
GetBatchControllerClient() *dataproc.BatchControllerClient
GetBatch(context.Context, string) (map[string]any, error)
GetProject() string
GetLocation() string
}

type Config struct {
@@ -104,15 +109,54 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}

client := source.GetBatchControllerClient()

paramMap := params.AsMap()
name, ok := paramMap["name"].(string)
if !ok {
return nil, fmt.Errorf("missing required parameter: name")
}

if strings.Contains(name, "/") {
return nil, fmt.Errorf("name must be a short batch name without '/': %s", name)
}
return source.GetBatch(ctx, name)

req := &dataprocpb.GetBatchRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", source.GetProject(), source.GetLocation(), name),
}

batchPb, err := client.GetBatch(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to get batch: %w", err)
}

jsonBytes, err := protojson.Marshal(batchPb)
if err != nil {
return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
}

var result map[string]any
if err := json.Unmarshal(jsonBytes, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
}

consoleUrl, err := common.BatchConsoleURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating console url: %v", err)
}
logsUrl, err := common.BatchLogsURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating logs url: %v", err)
}

wrappedResult := map[string]any{
"consoleUrl": consoleUrl,
"logsUrl": logsUrl,
"batch": result,
}

return wrappedResult, nil
}
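
A matching source-side GetBatch sketch, assuming the source holds the batch controller client and its project/location (s.Client, s.Project, s.Location are illustrative; the body mirrors the code removed above):

func (s *Source) GetBatch(ctx context.Context, name string) (map[string]any, error) {
    req := &dataprocpb.GetBatchRequest{
        Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", s.Project, s.Location, name),
    }
    batchPb, err := s.Client.GetBatch(ctx, req)
    if err != nil {
        return nil, fmt.Errorf("failed to get batch: %w", err)
    }
    // Round-trip through protojson to get a plain map for the response.
    jsonBytes, err := protojson.Marshal(batchPb)
    if err != nil {
        return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
    }
    var result map[string]any
    if err := json.Unmarshal(jsonBytes, &result); err != nil {
        return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
    }
    consoleUrl, err := common.BatchConsoleURLFromProto(batchPb)
    if err != nil {
        return nil, err
    }
    logsUrl, err := common.BatchLogsURLFromProto(batchPb)
    if err != nil {
        return nil, err
    }
    return map[string]any{"consoleUrl": consoleUrl, "logsUrl": logsUrl, "batch": result}, nil
}
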
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.Parameters, data, claims)

@@ -17,13 +17,17 @@ package serverlesssparklistbatches
import (
"context"
"fmt"
"time"

dataproc "cloud.google.com/go/dataproc/v2/apiv1"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/iterator"
)

const kind = "serverless-spark-list-batches"
@@ -44,7 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
GetBatchControllerClient() *dataproc.BatchControllerClient
ListBatches(context.Context, *int, string, string) (any, error)
GetProject() string
GetLocation() string
}

type Config struct {
@@ -99,24 +104,95 @@ type Tool struct {
Parameters parameters.Parameters
}

// ListBatchesResponse is the response from the list batches API.
type ListBatchesResponse struct {
Batches []Batch `json:"batches"`
NextPageToken string `json:"nextPageToken"`
}

// Batch represents a single batch job.
type Batch struct {
Name string `json:"name"`
UUID string `json:"uuid"`
State string `json:"state"`
Creator string `json:"creator"`
CreateTime string `json:"createTime"`
Operation string `json:"operation"`
ConsoleURL string `json:"consoleUrl"`
LogsURL string `json:"logsUrl"`
}

// Invoke executes the tool's operation.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramMap := params.AsMap()
var pageSize *int
if ps, ok := paramMap["pageSize"]; ok && ps != nil {
pageSizeV := ps.(int)
if pageSizeV <= 0 {
return nil, fmt.Errorf("pageSize must be positive: %d", pageSizeV)
}
pageSize = &pageSizeV

client := source.GetBatchControllerClient()

parent := fmt.Sprintf("projects/%s/locations/%s", source.GetProject(), source.GetLocation())
req := &dataprocpb.ListBatchesRequest{
Parent: parent,
OrderBy: "create_time desc",
}
pt, _ := paramMap["pageToken"].(string)
filter, _ := paramMap["filter"].(string)
return source.ListBatches(ctx, pageSize, pt, filter)

paramMap := params.AsMap()
if ps, ok := paramMap["pageSize"]; ok && ps != nil {
req.PageSize = int32(ps.(int))
if (req.PageSize) <= 0 {
return nil, fmt.Errorf("pageSize must be positive: %d", req.PageSize)
}
}
if pt, ok := paramMap["pageToken"]; ok && pt != nil {
req.PageToken = pt.(string)
}
if filter, ok := paramMap["filter"]; ok && filter != nil {
req.Filter = filter.(string)
}

it := client.ListBatches(ctx, req)
pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)

var batchPbs []*dataprocpb.Batch
nextPageToken, err := pager.NextPage(&batchPbs)
if err != nil {
return nil, fmt.Errorf("failed to list batches: %w", err)
}

batches, err := ToBatches(batchPbs)
if err != nil {
return nil, err
}

return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
}

// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
batches := make([]Batch, 0, len(batchPbs))
for _, batchPb := range batchPbs {
consoleUrl, err := common.BatchConsoleURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating console url: %v", err)
}
logsUrl, err := common.BatchLogsURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating logs url: %v", err)
}
batch := Batch{
Name: batchPb.Name,
UUID: batchPb.Uuid,
State: batchPb.State.Enum().String(),
Creator: batchPb.Creator,
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
Operation: batchPb.Operation,
ConsoleURL: consoleUrl,
LogsURL: logsUrl,
}
batches = append(batches, batch)
}
return batches, nil
}
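
A source-side ListBatches sketch, under the assumption that the source keeps the batch controller client and can reuse the ListBatchesResponse/ToBatches helpers above (where those types ultimately live is a design choice of the refactor; names here are illustrative):

func (s *Source) ListBatches(ctx context.Context, pageSize *int, pageToken, filter string) (any, error) {
    req := &dataprocpb.ListBatchesRequest{
        Parent:    fmt.Sprintf("projects/%s/locations/%s", s.Project, s.Location),
        OrderBy:   "create_time desc",
        PageToken: pageToken,
        Filter:    filter,
    }
    size := 0
    if pageSize != nil {
        size = *pageSize
        req.PageSize = int32(size)
    }
    it := s.Client.ListBatches(ctx, req)
    // NewPager fetches exactly one page of the requested size.
    pager := iterator.NewPager(it, size, pageToken)
    var batchPbs []*dataprocpb.Batch
    nextPageToken, err := pager.NextPage(&batchPbs)
    if err != nil {
        return nil, fmt.Errorf("failed to list batches: %w", err)
    }
    batches, err := ToBatches(batchPbs) // conversion shown above
    if err != nil {
        return nil, err
    }
    return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
}
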
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

@@ -33,8 +33,8 @@ import (
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
"github.com/googleapis/genai-toolbox/tests"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
@@ -676,7 +676,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
filter string
pageSize int
numPages int
want []serverlessspark.Batch
want []serverlesssparklistbatches.Batch
}{
{name: "one page", pageSize: 2, numPages: 1, want: batch2},
{name: "two pages", pageSize: 1, numPages: 2, want: batch2},
@@ -701,7 +701,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var actual []serverlessspark.Batch
var actual []serverlesssparklistbatches.Batch
var pageToken string
for i := 0; i < tc.numPages; i++ {
request := map[string]any{
@@ -733,7 +733,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
t.Fatalf("unable to find result in response body")
}

var listResponse serverlessspark.ListBatchesResponse
var listResponse serverlesssparklistbatches.ListBatchesResponse
if err := json.Unmarshal([]byte(result), &listResponse); err != nil {
t.Fatalf("error unmarshalling result: %s", err)
}
@@ -759,7 +759,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
}
}

func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, filter string, n int, exact bool) []serverlessspark.Batch {
func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, filter string, n int, exact bool) []serverlesssparklistbatches.Batch {
parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation)
req := &dataprocpb.ListBatchesRequest{
Parent: parent,
@@ -783,7 +783,7 @@ func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx co
if !exact && (len(batchPbs) == 0 || len(batchPbs) > n) {
t.Fatalf("expected between 1 and %d batches, got %d", n, len(batchPbs))
}
batches, err := serverlessspark.ToBatches(batchPbs)
batches, err := serverlesssparklistbatches.ToBatches(batchPbs)
if err != nil {
t.Fatalf("failed to convert batches to JSON: %v", err)
}