mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-08 15:14:00 -05:00
feat(neo4j): Add dry_run parameter to validate Cypher queries (#1769)
This pull request adds support for a new `dry_run` mode to the Neo4j Cypher execution tool, allowing users to validate queries and view execution plans without running them. It also sets a custom user agent for Neo4j connections and improves error handling and documentation. The most important changes are grouped below. ### New dry run feature for Cypher execution * Added an optional `dry_run` boolean parameter to the `neo4j-execute-cypher` tool, allowing users to validate Cypher queries and receive execution plan details without running the query. The tool now prepends `EXPLAIN` to the query when `dry_run` is true and returns a structured summary of the execution plan. [[1]](diffhunk://#diff-de7fdd7e68c95ea9813c704a89fffb8fd6de34e81b43a484623fdff7683e18f3L87-R93) [[2]](diffhunk://#diff-de7fdd7e68c95ea9813c704a89fffb8fd6de34e81b43a484623fdff7683e18f3R155-R188) [[3]](diffhunk://#diff-de7fdd7e68c95ea9813c704a89fffb8fd6de34e81b43a484623fdff7683e18f3R219-R236) [[4]](diffhunk://#diff-1dca93fc9450e9b9ea64bc1ae02774c3198ea6f8310b2437815bd1a5eae11e79L30-R32) * Updated integration tests to cover the new `dry_run` functionality, including successful dry runs, error handling for invalid syntax, and enforcement of read-only mode. [[1]](diffhunk://#diff-b07de4a304bc72964b5de9481cbc6aec6cf9bb9dabd903a837eb8974e7100a90R163-R169) [[2]](diffhunk://#diff-b07de4a304bc72964b5de9481cbc6aec6cf9bb9dabd903a837eb8974e7100a90R250-R291) ### Improved error handling * Enhanced error messages for parameter casting in the tool's `Invoke` method to clarify issues with input parameters. ### Neo4j driver configuration * Set a custom user agent (`genai-toolbox/neo4j-source`) for Neo4j driver connections to help identify requests from this tool. [[1]](diffhunk://#diff-3f0444add0913f1722d678118ffedc70039cca3603f31c9927c06be5e00ffb29R24-R29) [[2]](diffhunk://#diff-3f0444add0913f1722d678118ffedc70039cca3603f31c9927c06be5e00ffb29L109-R113) ### Documentation updates * Updated the documentation to describe the new `dry_run` parameter and its usage for query validation. --------- Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
@@ -27,8 +27,9 @@ Cypher](https://neo4j.com/docs/cypher-manual/current/queries/) syntax and
|
||||
supports all Cypher features, including pattern matching, filtering, and
|
||||
aggregation.
|
||||
|
||||
`neo4j-execute-cypher` takes one input parameter `cypher` and run the cypher
|
||||
query against the `source`.
|
||||
`neo4j-execute-cypher` takes a required input parameter `cypher` and run the cypher
|
||||
query against the `source`. It also supports an optional `dry_run`
|
||||
parameter to validate a query without executing it.
|
||||
|
||||
> **Note:** This tool is intended for developer assistant workflows with
|
||||
> human-in-the-loop and shouldn't be used for production agents.
|
||||
|
||||
@@ -20,7 +20,9 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
|
||||
neo4jconf "github.com/neo4j/neo4j-go-driver/v5/neo4j/config"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
@@ -106,7 +108,13 @@ func initNeo4jDriver(ctx context.Context, tracer trace.Tracer, uri, user, passwo
|
||||
defer span.End()
|
||||
|
||||
auth := neo4j.BasicAuth(user, password, "")
|
||||
driver, err := neo4j.NewDriverWithContext(uri, auth)
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
driver, err := neo4j.NewDriverWithContext(uri, auth, func(config *neo4jconf.Config) {
|
||||
config.UserAgent = userAgent
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create connection driver: %w", err)
|
||||
}
|
||||
|
||||
@@ -84,7 +84,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
cypherParameter := tools.NewStringParameter("cypher", "The cypher to execute.")
|
||||
parameters := tools.Parameters{cypherParameter}
|
||||
dryRunParameter := tools.NewBooleanParameterWithDefault(
|
||||
"dry_run",
|
||||
false,
|
||||
"If set to true, the query will be validated and information about the execution "+
|
||||
"will be returned without running the query. Defaults to false.",
|
||||
)
|
||||
parameters := tools.Parameters{cypherParameter, dryRunParameter}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)
|
||||
|
||||
@@ -124,13 +130,18 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
paramsMap := params.AsMap()
|
||||
cypherStr, ok := paramsMap["cypher"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to get cast %s", paramsMap["cypher"])
|
||||
return nil, fmt.Errorf("unable to cast cypher parameter %s", paramsMap["cypher"])
|
||||
}
|
||||
|
||||
if cypherStr == "" {
|
||||
return nil, fmt.Errorf("parameter 'cypher' must be a non-empty string")
|
||||
}
|
||||
|
||||
dryRun, ok := paramsMap["dry_run"].(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
|
||||
}
|
||||
|
||||
// validate the cypher query before executing
|
||||
cf := t.classifier.Classify(cypherStr)
|
||||
if cf.Error != nil {
|
||||
@@ -141,6 +152,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("this tool is read-only and cannot execute write queries")
|
||||
}
|
||||
|
||||
if dryRun {
|
||||
// Add EXPLAIN to the beginning of the query to validate it without executing
|
||||
cypherStr = "EXPLAIN " + cypherStr
|
||||
}
|
||||
|
||||
config := neo4j.ExecuteQueryWithDatabase(t.Database)
|
||||
results, err := neo4j.ExecuteQuery(ctx, t.Driver, cypherStr, nil,
|
||||
neo4j.EagerResultTransformer, config)
|
||||
@@ -148,9 +164,28 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
// If dry run, return the summary information only
|
||||
if dryRun {
|
||||
summary := results.Summary
|
||||
plan := summary.Plan()
|
||||
execPlan := map[string]any{
|
||||
"queryType": cf.Type.String(),
|
||||
"statementType": summary.StatementType(),
|
||||
"operator": plan.Operator(),
|
||||
"arguments": plan.Arguments(),
|
||||
"identifiers": plan.Identifiers(),
|
||||
"childrenCount": len(plan.Children()),
|
||||
}
|
||||
if len(plan.Children()) > 0 {
|
||||
execPlan["children"] = addPlanChildren(plan)
|
||||
}
|
||||
return []map[string]any{execPlan}, nil
|
||||
}
|
||||
|
||||
var out []any
|
||||
keys := results.Keys
|
||||
records := results.Records
|
||||
|
||||
for _, record := range records {
|
||||
vMap := make(map[string]any)
|
||||
for col, value := range record.Values {
|
||||
@@ -181,3 +216,21 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Recursive function to add plan children
|
||||
func addPlanChildren(p neo4j.Plan) []map[string]any {
|
||||
var children []map[string]any
|
||||
for _, child := range p.Children() {
|
||||
childMap := map[string]any{
|
||||
"operator": child.Operator(),
|
||||
"arguments": child.Arguments(),
|
||||
"identifiers": child.Identifiers(),
|
||||
"children_count": len(child.Children()),
|
||||
}
|
||||
if len(child.Children()) > 0 {
|
||||
childMap["children"] = addPlanChildren(child)
|
||||
}
|
||||
children = append(children, childMap)
|
||||
}
|
||||
return children
|
||||
}
|
||||
|
||||
@@ -160,6 +160,13 @@ func TestNeo4jToolEndpoints(t *testing.T) {
|
||||
"description": "The cypher to execute.",
|
||||
"authSources": []any{},
|
||||
},
|
||||
map[string]any{
|
||||
"name": "dry_run",
|
||||
"type": "boolean",
|
||||
"required": false,
|
||||
"description": "If set to true, the query will be validated and information about the execution will be returned without running the query. Defaults to false.",
|
||||
"authSources": []any{},
|
||||
},
|
||||
},
|
||||
"authRequired": []any{},
|
||||
},
|
||||
@@ -240,6 +247,51 @@ func TestNeo4jToolEndpoints(t *testing.T) {
|
||||
want: "[{\"a\":1}]",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invoke my-simple-execute-cypher-tool with dry_run",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"cypher": "MATCH (n:Test) RETURN n", "dry_run": true}`)),
|
||||
wantStatus: http.StatusOK,
|
||||
validateFunc: func(t *testing.T, body string) {
|
||||
var result []map[string]any
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal dry_run result: %v", err)
|
||||
}
|
||||
if len(result) == 0 {
|
||||
t.Fatalf("expected a query plan, but got an empty result")
|
||||
}
|
||||
|
||||
operatorValue, ok := result[0]["operator"]
|
||||
if !ok {
|
||||
t.Fatalf("expected key 'Operator' not found in dry_run response: %s", body)
|
||||
}
|
||||
|
||||
operatorStr, ok := operatorValue.(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected 'Operator' to be a string, but got %T", operatorValue)
|
||||
}
|
||||
|
||||
if operatorStr != "ProduceResults@neo4j" {
|
||||
t.Errorf("unexpected operator: got %q, want %q", operatorStr, "ProduceResults@neo4j")
|
||||
}
|
||||
|
||||
childrenCount, ok := result[0]["childrenCount"]
|
||||
if !ok {
|
||||
t.Fatalf("expected key 'ChildrenCount' not found in dry_run response: %s", body)
|
||||
}
|
||||
|
||||
if childrenCount.(float64) != 1 {
|
||||
t.Errorf("unexpected children count: got %v, want %d", childrenCount, 1)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invoke my-simple-execute-cypher-tool with dry_run and invalid syntax",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"cypher": "RTN 1", "dry_run": true}`)),
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErrorSubstring: "unable to execute query",
|
||||
},
|
||||
{
|
||||
name: "invoke readonly tool with write query",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke",
|
||||
@@ -247,6 +299,13 @@ func TestNeo4jToolEndpoints(t *testing.T) {
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErrorSubstring: "this tool is read-only and cannot execute write queries",
|
||||
},
|
||||
{
|
||||
name: "invoke readonly tool with write query and dry_run",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"cypher": "CREATE (n:TestNode)", "dry_run": true}`)),
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErrorSubstring: "this tool is read-only and cannot execute write queries",
|
||||
},
|
||||
{
|
||||
name: "invoke my-schema-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-schema-tool/invoke",
|
||||
|
||||
Reference in New Issue
Block a user