diff --git a/docs/en/resources/tools/neo4j/neo4j-execute-cypher.md b/docs/en/resources/tools/neo4j/neo4j-execute-cypher.md index c03b843761..f544e1c25d 100644 --- a/docs/en/resources/tools/neo4j/neo4j-execute-cypher.md +++ b/docs/en/resources/tools/neo4j/neo4j-execute-cypher.md @@ -27,8 +27,9 @@ Cypher](https://neo4j.com/docs/cypher-manual/current/queries/) syntax and supports all Cypher features, including pattern matching, filtering, and aggregation. -`neo4j-execute-cypher` takes one input parameter `cypher` and run the cypher -query against the `source`. +`neo4j-execute-cypher` takes a required input parameter `cypher` and run the cypher +query against the `source`. It also supports an optional `dry_run` +parameter to validate a query without executing it. > **Note:** This tool is intended for developer assistant workflows with > human-in-the-loop and shouldn't be used for production agents. diff --git a/internal/sources/neo4j/neo4j.go b/internal/sources/neo4j/neo4j.go index 9c1b9c76f6..2262a54bc9 100644 --- a/internal/sources/neo4j/neo4j.go +++ b/internal/sources/neo4j/neo4j.go @@ -20,7 +20,9 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util" "github.com/neo4j/neo4j-go-driver/v5/neo4j" + neo4jconf "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" "go.opentelemetry.io/otel/trace" ) @@ -106,7 +108,13 @@ func initNeo4jDriver(ctx context.Context, tracer trace.Tracer, uri, user, passwo defer span.End() auth := neo4j.BasicAuth(user, password, "") - driver, err := neo4j.NewDriverWithContext(uri, auth) + userAgent, err := util.UserAgentFromContext(ctx) + if err != nil { + return nil, err + } + driver, err := neo4j.NewDriverWithContext(uri, auth, func(config *neo4jconf.Config) { + config.UserAgent = userAgent + }) if err != nil { return nil, fmt.Errorf("unable to create connection driver: %w", err) } diff --git a/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go b/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go index ee26330854..33c474d509 100644 --- a/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go +++ b/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go @@ -84,7 +84,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } cypherParameter := tools.NewStringParameter("cypher", "The cypher to execute.") - parameters := tools.Parameters{cypherParameter} + dryRunParameter := tools.NewBooleanParameterWithDefault( + "dry_run", + false, + "If set to true, the query will be validated and information about the execution "+ + "will be returned without running the query. Defaults to false.", + ) + parameters := tools.Parameters{cypherParameter, dryRunParameter} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters) @@ -124,13 +130,18 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken paramsMap := params.AsMap() cypherStr, ok := paramsMap["cypher"].(string) if !ok { - return nil, fmt.Errorf("unable to get cast %s", paramsMap["cypher"]) + return nil, fmt.Errorf("unable to cast cypher parameter %s", paramsMap["cypher"]) } if cypherStr == "" { return nil, fmt.Errorf("parameter 'cypher' must be a non-empty string") } + dryRun, ok := paramsMap["dry_run"].(bool) + if !ok { + return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"]) + } + // validate the cypher query before executing cf := t.classifier.Classify(cypherStr) if cf.Error != nil { @@ -141,6 +152,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken return nil, fmt.Errorf("this tool is read-only and cannot execute write queries") } + if dryRun { + // Add EXPLAIN to the beginning of the query to validate it without executing + cypherStr = "EXPLAIN " + cypherStr + } + config := neo4j.ExecuteQueryWithDatabase(t.Database) results, err := neo4j.ExecuteQuery(ctx, t.Driver, cypherStr, nil, neo4j.EagerResultTransformer, config) @@ -148,9 +164,28 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken return nil, fmt.Errorf("unable to execute query: %w", err) } + // If dry run, return the summary information only + if dryRun { + summary := results.Summary + plan := summary.Plan() + execPlan := map[string]any{ + "queryType": cf.Type.String(), + "statementType": summary.StatementType(), + "operator": plan.Operator(), + "arguments": plan.Arguments(), + "identifiers": plan.Identifiers(), + "childrenCount": len(plan.Children()), + } + if len(plan.Children()) > 0 { + execPlan["children"] = addPlanChildren(plan) + } + return []map[string]any{execPlan}, nil + } + var out []any keys := results.Keys records := results.Records + for _, record := range records { vMap := make(map[string]any) for col, value := range record.Values { @@ -181,3 +216,21 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { func (t Tool) RequiresClientAuthorization() bool { return false } + +// Recursive function to add plan children +func addPlanChildren(p neo4j.Plan) []map[string]any { + var children []map[string]any + for _, child := range p.Children() { + childMap := map[string]any{ + "operator": child.Operator(), + "arguments": child.Arguments(), + "identifiers": child.Identifiers(), + "children_count": len(child.Children()), + } + if len(child.Children()) > 0 { + childMap["children"] = addPlanChildren(child) + } + children = append(children, childMap) + } + return children +} diff --git a/tests/neo4j/neo4j_integration_test.go b/tests/neo4j/neo4j_integration_test.go index 7d60083c66..816edba6e9 100644 --- a/tests/neo4j/neo4j_integration_test.go +++ b/tests/neo4j/neo4j_integration_test.go @@ -160,6 +160,13 @@ func TestNeo4jToolEndpoints(t *testing.T) { "description": "The cypher to execute.", "authSources": []any{}, }, + map[string]any{ + "name": "dry_run", + "type": "boolean", + "required": false, + "description": "If set to true, the query will be validated and information about the execution will be returned without running the query. Defaults to false.", + "authSources": []any{}, + }, }, "authRequired": []any{}, }, @@ -240,6 +247,51 @@ func TestNeo4jToolEndpoints(t *testing.T) { want: "[{\"a\":1}]", wantStatus: http.StatusOK, }, + { + name: "invoke my-simple-execute-cypher-tool with dry_run", + api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/invoke", + requestBody: bytes.NewBuffer([]byte(`{"cypher": "MATCH (n:Test) RETURN n", "dry_run": true}`)), + wantStatus: http.StatusOK, + validateFunc: func(t *testing.T, body string) { + var result []map[string]any + if err := json.Unmarshal([]byte(body), &result); err != nil { + t.Fatalf("failed to unmarshal dry_run result: %v", err) + } + if len(result) == 0 { + t.Fatalf("expected a query plan, but got an empty result") + } + + operatorValue, ok := result[0]["operator"] + if !ok { + t.Fatalf("expected key 'Operator' not found in dry_run response: %s", body) + } + + operatorStr, ok := operatorValue.(string) + if !ok { + t.Fatalf("expected 'Operator' to be a string, but got %T", operatorValue) + } + + if operatorStr != "ProduceResults@neo4j" { + t.Errorf("unexpected operator: got %q, want %q", operatorStr, "ProduceResults@neo4j") + } + + childrenCount, ok := result[0]["childrenCount"] + if !ok { + t.Fatalf("expected key 'ChildrenCount' not found in dry_run response: %s", body) + } + + if childrenCount.(float64) != 1 { + t.Errorf("unexpected children count: got %v, want %d", childrenCount, 1) + } + }, + }, + { + name: "invoke my-simple-execute-cypher-tool with dry_run and invalid syntax", + api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/invoke", + requestBody: bytes.NewBuffer([]byte(`{"cypher": "RTN 1", "dry_run": true}`)), + wantStatus: http.StatusBadRequest, + wantErrorSubstring: "unable to execute query", + }, { name: "invoke readonly tool with write query", api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke", @@ -247,6 +299,13 @@ func TestNeo4jToolEndpoints(t *testing.T) { wantStatus: http.StatusBadRequest, wantErrorSubstring: "this tool is read-only and cannot execute write queries", }, + { + name: "invoke readonly tool with write query and dry_run", + api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke", + requestBody: bytes.NewBuffer([]byte(`{"cypher": "CREATE (n:TestNode)", "dry_run": true}`)), + wantStatus: http.StatusBadRequest, + wantErrorSubstring: "this tool is read-only and cannot execute write queries", + }, { name: "invoke my-schema-tool", api: "http://127.0.0.1:5000/api/tool/my-schema-tool/invoke",