feat: Neo4j tools enhancements - neo4j-execute-cypher (#946)

This pull request introduces a new tool for executing arbitrary Cypher
queries against a Neo4j database (`neo4j-execute-cypher`) and implements
robust query classification functionality to distinguish between read
and write operations. The changes include updates to documentation, the
addition of a query classifier, and comprehensive test coverage for the
classifier.

### Addition of `neo4j-execute-cypher` tool:

- **Documentation**: Added a new markdown file `neo4j-execute-cypher.md`
that explains the tool's functionality, usage, and configuration
options, including the ability to enforce read-only mode for security.
- **Import statement**: Registered the new tool in the `cmd/root.go`
file to make it available in the toolbox.

### Query classification functionality:

- **Query classifier implementation**: Added `QueryClassifier` in
`classifier.go`, which classifies Cypher queries into read or write
operations based on keywords, procedures, and subquery analysis. It
supports handling edge cases like nested subqueries, multi-word
keywords, and invalid syntax.
- **Test coverage**: Created extensive tests in `classifier_test.go` to
validate the classifier's behavior across various query types, including
abuse cases, subqueries, and procedure calls. Tests ensure the
classifier is robust and does not panic on malformed queries.

---------

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
nester-neo4j
2025-07-23 13:33:44 -04:00
committed by GitHub
parent 15417d4e0c
commit 81d05053b2
7 changed files with 1212 additions and 17 deletions

View File

@@ -64,6 +64,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlsql"
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jcypher"
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
_ "github.com/googleapis/genai-toolbox/internal/tools/redis"

View File

@@ -0,0 +1,53 @@
---
title: "neo4j-execute-cypher"
type: docs
weight: 1
description: >
A "neo4j-execute-cypher" tool executes any arbitrary Cypher statement against a Neo4j
database.
aliases:
- /resources/tools/neo4j-execute-cypher
---
## About
A `neo4j-execute-cypher` tool executes an arbitrary Cypher query provided as a string parameter against a Neo4j database. It's designed to be a flexible tool for interacting with the database when a pre-defined query is not sufficient. This tool is compatible with any of the following sources:
- [neo4j](../sources/neo4j.md)
For security, the tool can be configured to be read-only. If the `readOnly` flag is set to `true`, the tool will analyze the incoming Cypher query and reject any write operations (like `CREATE`, `MERGE`, `DELETE`, etc.) before execution.
The Cypher query uses standard [Neo4j Cypher](https://neo4j.com/docs/cypher-manual/current/queries/) syntax and supports all Cypher features, including pattern matching, filtering, and aggregation.
`neo4j-execute-cypher` takes one input parameter `cypher` and run the cypher query against the `source`.
> **Note:** This tool is intended for developer assistant workflows with
> human-in-the-loop and shouldn't be used for production agents.
## Example
```yaml
tools:
query_neo4j:
kind: neo4j-execute-cypher
source: my-neo4j-prod-db
readOnly: true
description: |
Use this tool to execute a Cypher query against the production database.
Only read-only queries are allowed.
Takes a single 'cypher' parameter containing the full query string.
Example:
{{
"cypher": "MATCH (m:Movie {title: 'The Matrix'}) RETURN m.released"
}}
```
## Reference
| **field** | **type** | **required** | **description** |
|-------------|:------------------------------------------:|:------------:|-------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "neo4j-cypher". |
| source | string | true | Name of the source the Cypher query should execute on. |
| description | string | true | Description of the tool that is passed to the LLM. |
| readOnly | boolean | false | If set to `true`, the tool will reject any write operations in the Cypher query. Default is `false`. |

View File

@@ -0,0 +1,434 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
Package classifier provides tools to classify Cypher queries as either read-only or write operations.
It uses a keyword-based and procedure-based approach to determine the query's nature.
The main entry point is the `Classify` method on a `QueryClassifier` object. The classifier
is designed to be conservative, defaulting to classifying unknown procedures as write
operations to ensure safety in read-only environments.
It can handle:
- Standard Cypher keywords (MATCH, CREATE, MERGE, etc.).
- Multi-word keywords (DETACH DELETE, ORDER BY).
- Comments and string literals, which are ignored during classification.
- Procedure calls (CALL db.labels), with predefined lists of known read/write procedures.
- Subqueries (CALL { ... }), checking for write operations within the subquery block.
*/
package classifier
import (
"regexp"
"sort"
"strings"
)
// QueryType represents the classification of a Cypher query as either read or write.
type QueryType int
const (
// ReadQuery indicates a query that only reads data.
ReadQuery QueryType = iota
// WriteQuery indicates a query that modifies data.
WriteQuery
)
// String provides a human-readable representation of the QueryType.
func (qt QueryType) String() string {
if qt == ReadQuery {
return "READ"
}
return "WRITE"
}
// QueryClassification represents the detailed result of a query classification.
type QueryClassification struct {
// Type is the overall classification of the query (READ or WRITE).
Type QueryType
// Confidence is a score from 0.0 to 1.0 indicating the classifier's certainty.
// 1.0 is fully confident. Lower scores may be assigned for ambiguous cases,
// like unknown procedures.
Confidence float64
// WriteTokens is a list of keywords or procedures found that indicate a write operation.
WriteTokens []string
// ReadTokens is a list of keywords or procedures found that indicate a read operation.
ReadTokens []string
// HasSubquery is true if the query contains a `CALL { ... }` block.
HasSubquery bool
// Error holds any error that occurred during classification, though this is not
// currently used in the implementation.
Error error
}
// QueryClassifier contains the logic and data for classifying Cypher queries.
// It should be instantiated via the NewQueryClassifier() function.
type QueryClassifier struct {
writeKeywords map[string]struct{}
readKeywords map[string]struct{}
// writeProcedures is a map of known write procedure prefixes for quick lookup.
writeProcedures map[string]struct{}
// readProcedures is a map of known read procedure prefixes for quick lookup.
readProcedures map[string]struct{}
multiWordWriteKeywords []string
multiWordReadKeywords []string
commentPattern *regexp.Regexp
stringLiteralPattern *regexp.Regexp
procedureCallPattern *regexp.Regexp
subqueryPattern *regexp.Regexp
whitespacePattern *regexp.Regexp
tokenSplitPattern *regexp.Regexp
}
// NewQueryClassifier creates and initializes a new QueryClassifier instance.
// It pre-compiles regular expressions and populates the internal lists of
// known Cypher keywords and procedures.
func NewQueryClassifier() *QueryClassifier {
c := &QueryClassifier{
writeKeywords: make(map[string]struct{}),
readKeywords: make(map[string]struct{}),
writeProcedures: make(map[string]struct{}),
readProcedures: make(map[string]struct{}),
commentPattern: regexp.MustCompile(`(?m)//.*?$|/\*[\s\S]*?\*/`),
stringLiteralPattern: regexp.MustCompile(`'[^']*'|"[^"]*"`),
procedureCallPattern: regexp.MustCompile(`(?i)\bCALL\s+([a-zA-Z0-9_.]+)`),
subqueryPattern: regexp.MustCompile(`(?i)\bCALL\s*\{`),
whitespacePattern: regexp.MustCompile(`\s+`),
tokenSplitPattern: regexp.MustCompile(`[\s,(){}[\]]+`),
}
// Lists of known keywords that perform write operations.
writeKeywordsList := []string{
"CREATE", "MERGE", "DELETE", "DETACH DELETE", "SET", "REMOVE", "FOREACH",
"CREATE INDEX", "DROP INDEX", "CREATE CONSTRAINT", "DROP CONSTRAINT",
}
// Lists of known keywords that perform read operations.
readKeywordsList := []string{
"MATCH", "OPTIONAL MATCH", "WITH", "WHERE", "RETURN", "ORDER BY", "SKIP", "LIMIT",
"UNION", "UNION ALL", "UNWIND", "CASE", "WHEN", "THEN", "ELSE", "END",
"SHOW", "PROFILE", "EXPLAIN",
}
// A list of procedure prefixes known to perform write operations.
writeProceduresList := []string{
"apoc.create", "apoc.merge", "apoc.refactor", "apoc.atomic", "apoc.trigger",
"apoc.periodic.commit", "apoc.load.jdbc", "apoc.load.json", "apoc.load.csv",
"apoc.export", "apoc.import", "db.create", "db.drop", "db.index.create",
"db.constraints.create", "dbms.security.create", "gds.graph.create", "gds.graph.drop",
}
// A list of procedure prefixes known to perform read operations.
readProceduresList := []string{
"apoc.meta", "apoc.help", "apoc.version", "apoc.text", "apoc.math", "apoc.coll",
"apoc.path", "apoc.algo", "apoc.date", "db.labels", "db.propertyKeys",
"db.relationshipTypes", "db.schema", "db.indexes", "db.constraints",
"dbms.components", "dbms.listConfig", "gds.graph.list", "gds.util",
}
c.populateKeywords(writeKeywordsList, c.writeKeywords, &c.multiWordWriteKeywords)
c.populateKeywords(readKeywordsList, c.readKeywords, &c.multiWordReadKeywords)
c.populateProcedures(writeProceduresList, c.writeProcedures)
c.populateProcedures(readProceduresList, c.readProcedures)
return c
}
// populateKeywords processes a list of keyword strings, separating them into
// single-word and multi-word lists for easier processing later.
// Multi-word keywords (e.g., "DETACH DELETE") are sorted by length descending
// to ensure longer matches are replaced first.
func (c *QueryClassifier) populateKeywords(keywords []string, keywordMap map[string]struct{}, multiWord *[]string) {
for _, kw := range keywords {
if strings.Contains(kw, " ") {
*multiWord = append(*multiWord, kw)
}
// Replace spaces with underscores for unified tokenization.
keywordMap[strings.ReplaceAll(kw, " ", "_")] = struct{}{}
}
// Sort multi-word keywords by length (longest first) to prevent
// partial matches, e.g., replacing "CREATE OR REPLACE" before "CREATE".
sort.SliceStable(*multiWord, func(i, j int) bool {
return len((*multiWord)[i]) > len((*multiWord)[j])
})
}
// populateProcedures adds a list of procedure prefixes to the given map.
func (c *QueryClassifier) populateProcedures(procedures []string, procedureMap map[string]struct{}) {
for _, proc := range procedures {
procedureMap[strings.ToLower(proc)] = struct{}{}
}
}
// Classify analyzes a Cypher query string and returns a QueryClassification result.
// It is the main method for this package.
//
// The process is as follows:
// 1. Normalize the query by removing comments and extra whitespace.
// 2. Replace string literals to prevent keywords inside them from being classified.
// 3. Unify multi-word keywords (e.g., "DETACH DELETE" becomes "DETACH_DELETE").
// 4. Extract all procedure calls (e.g., `CALL db.labels`).
// 5. Tokenize the remaining query string.
// 6. Check tokens and procedures against known read/write lists.
// 7. If a subquery `CALL { ... }` exists, check its contents for write operations.
// 8. Assign a final classification and confidence score.
//
// Usage example:
//
// classifier := NewQueryClassifier()
// query := "MATCH (n:Person) WHERE n.name = 'Alice' SET n.age = 30"
// result := classifier.Classify(query)
// fmt.Printf("Query is a %s query with confidence %f\n", result.Type, result.Confidence)
// // Output: Query is a WRITE query with confidence 0.900000
// fmt.Printf("Write tokens found: %v\n", result.WriteTokens)
// // Output: Write tokens found: [SET]
func (c *QueryClassifier) Classify(query string) QueryClassification {
result := QueryClassification{
Type: ReadQuery, // Default to read, upgrade to write if write tokens are found.
Confidence: 1.0,
}
normalizedQuery := c.normalizeQuery(query)
if normalizedQuery == "" {
return result // Return default for empty queries.
}
// Early check for subqueries to set the flag.
result.HasSubquery = c.subqueryPattern.MatchString(normalizedQuery)
procedures := c.extractProcedureCalls(normalizedQuery)
// Sanitize the query by replacing string literals to avoid misinterpreting their contents.
sanitizedQuery := c.stringLiteralPattern.ReplaceAllString(normalizedQuery, "STRING_LITERAL")
// Unify multi-word keywords to treat them as single tokens.
unifiedQuery := c.unifyMultiWordKeywords(sanitizedQuery)
tokens := c.extractTokens(unifiedQuery)
// Classify based on standard keywords.
for _, token := range tokens {
upperToken := strings.ToUpper(token)
if _, isWrite := c.writeKeywords[upperToken]; isWrite {
result.WriteTokens = append(result.WriteTokens, upperToken)
result.Type = WriteQuery
} else if _, isRead := c.readKeywords[upperToken]; isRead {
result.ReadTokens = append(result.ReadTokens, upperToken)
}
}
// Classify based on procedure calls.
for _, proc := range procedures {
if c.isWriteProcedure(proc) {
result.WriteTokens = append(result.WriteTokens, "CALL "+proc)
result.Type = WriteQuery
} else if c.isReadProcedure(proc) {
result.ReadTokens = append(result.ReadTokens, "CALL "+proc)
} else {
// CONSERVATIVE APPROACH: If a procedure is not in a known list,
// we guess its type. If it looks like a read (get, list), we treat it as such.
// Otherwise, we assume it's a write operation with lower confidence.
if strings.Contains(proc, ".get") || strings.Contains(proc, ".list") ||
strings.Contains(proc, ".show") || strings.Contains(proc, ".meta") {
result.ReadTokens = append(result.ReadTokens, "CALL "+proc)
} else {
result.WriteTokens = append(result.WriteTokens, "CALL "+proc)
result.Type = WriteQuery
result.Confidence = 0.8 // Lower confidence for unknown procedures.
}
}
}
// If a subquery exists, explicitly check its contents for write operations.
if result.HasSubquery && c.hasWriteInSubquery(unifiedQuery) {
result.Type = WriteQuery
// Add a specific token to indicate the reason for the write classification.
found := false
for _, t := range result.WriteTokens {
if t == "WRITE_IN_SUBQUERY" {
found = true
break
}
}
if !found {
result.WriteTokens = append(result.WriteTokens, "WRITE_IN_SUBQUERY")
}
}
// If a query contains both read and write operations (e.g., MATCH ... DELETE),
// it's a write query. We lower the confidence slightly to reflect the mixed nature.
if len(result.WriteTokens) > 0 && len(result.ReadTokens) > 0 {
result.Confidence = 0.9
}
return result
}
// unifyMultiWordKeywords replaces multi-word keywords in a query with a single,
// underscore-separated token. This simplifies the tokenization process.
// Example: "DETACH DELETE" becomes "DETACH_DELETE".
func (c *QueryClassifier) unifyMultiWordKeywords(query string) string {
upperQuery := strings.ToUpper(query)
// Combine all multi-word keywords for a single pass.
allMultiWord := append(c.multiWordWriteKeywords, c.multiWordReadKeywords...)
for _, kw := range allMultiWord {
placeholder := strings.ReplaceAll(kw, " ", "_")
upperQuery = strings.ReplaceAll(upperQuery, kw, placeholder)
}
return upperQuery
}
// normalizeQuery cleans a query string by removing comments and collapsing
// all whitespace into single spaces.
func (c *QueryClassifier) normalizeQuery(query string) string {
// Remove single-line and multi-line comments.
query = c.commentPattern.ReplaceAllString(query, " ")
// Collapse consecutive whitespace characters into a single space.
query = c.whitespacePattern.ReplaceAllString(query, " ")
return strings.TrimSpace(query)
}
// extractTokens splits a query string into a slice of individual tokens.
// It splits on whitespace and various punctuation marks.
func (c *QueryClassifier) extractTokens(query string) []string {
tokens := c.tokenSplitPattern.Split(query, -1)
// Filter out empty strings that can result from the split.
result := make([]string, 0, len(tokens))
for _, token := range tokens {
if token != "" {
result = append(result, token)
}
}
return result
}
// extractProcedureCalls finds all procedure calls (e.g., `CALL db.labels`)
// in the query and returns a slice of their names.
func (c *QueryClassifier) extractProcedureCalls(query string) []string {
matches := c.procedureCallPattern.FindAllStringSubmatch(query, -1)
procedures := make([]string, 0, len(matches))
for _, match := range matches {
if len(match) > 1 {
procedures = append(procedures, strings.ToLower(match[1]))
}
}
return procedures
}
// isWriteProcedure checks if a given procedure name matches any of the known
// write procedure prefixes.
func (c *QueryClassifier) isWriteProcedure(procedure string) bool {
procedure = strings.ToLower(procedure)
for wp := range c.writeProcedures {
if strings.HasPrefix(procedure, wp) {
return true
}
}
return false
}
// isReadProcedure checks if a given procedure name matches any of the known
// read procedure prefixes.
func (c *QueryClassifier) isReadProcedure(procedure string) bool {
procedure = strings.ToLower(procedure)
for rp := range c.readProcedures {
if strings.HasPrefix(procedure, rp) {
return true
}
}
return false
}
// hasWriteInSubquery detects if a write keyword exists within a `CALL { ... }` block.
// It correctly handles nested braces to find the content of the top-level subquery.
func (c *QueryClassifier) hasWriteInSubquery(unifiedQuery string) bool {
loc := c.subqueryPattern.FindStringIndex(unifiedQuery)
if loc == nil {
return false
}
// The search starts from the beginning of the `CALL {` match.
subqueryContent := unifiedQuery[loc[0]:]
openBraces := 0
startIndex := -1
endIndex := -1
// Find the boundaries of the first complete `{...}` block.
for i, char := range subqueryContent {
if char == '{' {
if openBraces == 0 {
startIndex = i + 1
}
openBraces++
} else if char == '}' {
openBraces--
if openBraces == 0 {
endIndex = i
break
}
}
}
var block string
if startIndex != -1 {
if endIndex != -1 {
// A complete `{...}` block was found.
block = subqueryContent[startIndex:endIndex]
} else {
// An opening brace was found but no closing one; this indicates a
// likely syntax error, but we check the rest of the string anyway.
block = subqueryContent[startIndex:]
}
// Check if any write keyword exists as a whole word within the subquery block.
for writeOp := range c.writeKeywords {
// Use regex to match the keyword as a whole word to avoid partial matches
// (e.g., finding "SET" in "ASSET").
re := regexp.MustCompile(`\b` + writeOp + `\b`)
if re.MatchString(block) {
return true
}
}
}
return false
}
// AddWriteProcedure allows users to dynamically add a custom procedure prefix to the
// list of known write procedures. This is useful for environments with custom plugins.
// The pattern is matched using `strings.HasPrefix`.
//
// Usage example:
//
// classifier := NewQueryClassifier()
// classifier.AddWriteProcedure("my.custom.writer")
// result := classifier.Classify("CALL my.custom.writer.createUser()")
// // result.Type will be WriteQuery
func (c *QueryClassifier) AddWriteProcedure(pattern string) {
if pattern != "" {
c.writeProcedures[strings.ToLower(pattern)] = struct{}{}
}
}
// AddReadProcedure allows users to dynamically add a custom procedure prefix to the
// list of known read procedures.
// The pattern is matched using `strings.HasPrefix`.
//
// Usage example:
//
// classifier := NewQueryClassifier()
// classifier.AddReadProcedure("my.custom.reader")
// result := classifier.Classify("CALL my.custom.reader.getData()")
// // result.Type will be ReadQuery
func (c *QueryClassifier) AddReadProcedure(pattern string) {
if pattern != "" {
c.readProcedures[strings.ToLower(pattern)] = struct{}{}
}
}

View File

@@ -0,0 +1,367 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package classifier
import (
"reflect"
"sort"
"testing"
)
// assertElementsMatch checks if two string slices have the same elements, ignoring order.
// It serves as a replacement for testify's assert.ElementsMatch.
func assertElementsMatch(t *testing.T, expected, actual []string, msg string) {
// t.Helper() marks this function as a test helper.
// When t.Errorf is called from this function, the line number of the calling code is reported, not the line number inside this helper.
t.Helper()
if len(expected) == 0 && len(actual) == 0 {
return // Both are empty or nil, they match.
}
// Create copies to sort, leaving the original slices unmodified.
expectedCopy := make([]string, len(expected))
actualCopy := make([]string, len(actual))
copy(expectedCopy, expected)
copy(actualCopy, actual)
sort.Strings(expectedCopy)
sort.Strings(actualCopy)
// reflect.DeepEqual provides a robust comparison for complex types, including sorted slices.
if !reflect.DeepEqual(expectedCopy, actualCopy) {
t.Errorf("%s: \nexpected: %v\n got: %v", msg, expected, actual)
}
}
func TestQueryClassifier_Classify(t *testing.T) {
classifier := NewQueryClassifier()
tests := []struct {
name string
query string
expectedType QueryType
expectedWrite []string
expectedRead []string
minConfidence float64
}{
// Read queries
{
name: "simple MATCH query",
query: "MATCH (n:Person) RETURN n",
expectedType: ReadQuery,
expectedRead: []string{"MATCH", "RETURN"},
expectedWrite: []string{},
minConfidence: 1.0,
},
{
name: "complex read query",
query: "MATCH (p:Person)-[:KNOWS]->(f) WHERE p.age > 30 RETURN p.name, count(f) ORDER BY p.name SKIP 10 LIMIT 5",
expectedType: ReadQuery,
expectedRead: []string{"MATCH", "WHERE", "RETURN", "ORDER_BY", "SKIP", "LIMIT"},
expectedWrite: []string{},
minConfidence: 1.0,
},
{
name: "UNION query",
query: "MATCH (n:Person) RETURN n.name UNION MATCH (m:Company) RETURN m.name",
expectedType: ReadQuery,
expectedRead: []string{"MATCH", "RETURN", "UNION", "MATCH", "RETURN"},
expectedWrite: []string{},
minConfidence: 1.0,
},
// Write queries
{
name: "CREATE query",
query: "CREATE (n:Person {name: 'John', age: 30})",
expectedType: WriteQuery,
expectedWrite: []string{"CREATE"},
expectedRead: []string{},
minConfidence: 1.0,
},
{
name: "MERGE query",
query: "MERGE (n:Person {id: 123}) ON CREATE SET n.created = timestamp()",
expectedType: WriteQuery,
expectedWrite: []string{"MERGE", "CREATE", "SET"},
expectedRead: []string{},
minConfidence: 1.0,
},
{
name: "DETACH DELETE query",
query: "MATCH (n:Person) DETACH DELETE n",
expectedType: WriteQuery,
expectedWrite: []string{"DETACH_DELETE"},
expectedRead: []string{"MATCH"},
minConfidence: 0.9,
},
// Procedure calls
{
name: "read procedure",
query: "CALL db.labels() YIELD label RETURN label",
expectedType: ReadQuery,
expectedRead: []string{"RETURN", "CALL db.labels"},
expectedWrite: []string{},
minConfidence: 1.0,
},
{
name: "unknown procedure conservative",
query: "CALL custom.procedure.doSomething()",
expectedType: WriteQuery,
expectedWrite: []string{"CALL custom.procedure.dosomething"},
expectedRead: []string{},
minConfidence: 0.8,
},
{
name: "unknown read-like procedure",
query: "CALL custom.procedure.getUsers()",
expectedType: ReadQuery,
expectedRead: []string{"CALL custom.procedure.getusers"},
expectedWrite: []string{},
minConfidence: 1.0,
},
// Subqueries
{
name: "read subquery",
query: "CALL { MATCH (n:Person) RETURN n } RETURN n",
expectedType: ReadQuery,
expectedRead: []string{"MATCH", "RETURN", "RETURN"},
expectedWrite: []string{},
minConfidence: 1.0,
},
{
name: "write subquery",
query: "CALL { CREATE (n:Person) RETURN n } RETURN n",
expectedType: WriteQuery,
expectedWrite: []string{"CREATE", "WRITE_IN_SUBQUERY"},
expectedRead: []string{"RETURN", "RETURN"},
minConfidence: 0.9,
},
// Multiline Queries
{
name: "multiline read query with comments",
query: `
// Find all people and their friends
MATCH (p:Person)-[:KNOWS]->(f:Friend)
/*
Where the person is older than 25
*/
WHERE p.age > 25
RETURN p.name, f.name
`,
expectedType: ReadQuery,
expectedWrite: []string{},
expectedRead: []string{"MATCH", "WHERE", "RETURN"},
minConfidence: 1.0,
},
{
name: "multiline write query",
query: `
MATCH (p:Person {name: 'Alice'})
CREATE (c:Company {name: 'Neo4j'})
CREATE (p)-[:WORKS_FOR]->(c)
`,
expectedType: WriteQuery,
expectedWrite: []string{"CREATE", "CREATE"},
expectedRead: []string{"MATCH"},
minConfidence: 0.9,
},
// Complex Subqueries
{
name: "nested read subquery",
query: `
CALL {
MATCH (p:Person)
RETURN p
}
CALL {
MATCH (c:Company)
RETURN c
}
RETURN p, c
`,
expectedType: ReadQuery,
expectedWrite: []string{},
expectedRead: []string{"MATCH", "RETURN", "MATCH", "RETURN", "RETURN"},
minConfidence: 1.0,
},
{
name: "subquery with write and outer read",
query: `
MATCH (u:User {id: 1})
CALL {
WITH u
CREATE (p:Post {content: 'New post'})
CREATE (u)-[:AUTHORED]->(p)
RETURN p
}
RETURN u.name, p.content
`,
expectedType: WriteQuery,
expectedWrite: []string{"CREATE", "CREATE", "WRITE_IN_SUBQUERY"},
expectedRead: []string{"MATCH", "WITH", "RETURN", "RETURN"},
minConfidence: 0.9,
},
{
name: "subquery with read passing to outer write",
query: `
CALL {
MATCH (p:Product {id: 'abc'})
RETURN p
}
WITH p
SET p.lastViewed = timestamp()
`,
expectedType: WriteQuery,
expectedWrite: []string{"SET"},
expectedRead: []string{"MATCH", "RETURN", "WITH"},
minConfidence: 0.9,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := classifier.Classify(tt.query)
if tt.expectedType != result.Type {
t.Errorf("Query type mismatch: expected %v, got %v", tt.expectedType, result.Type)
}
if result.Confidence < tt.minConfidence {
t.Errorf("Confidence too low: expected at least %f, got %f", tt.minConfidence, result.Confidence)
}
assertElementsMatch(t, tt.expectedWrite, result.WriteTokens, "Write tokens mismatch")
assertElementsMatch(t, tt.expectedRead, result.ReadTokens, "Read tokens mismatch")
})
}
}
func TestQueryClassifier_AbuseCases(t *testing.T) {
classifier := NewQueryClassifier()
tests := []struct {
name string
query string
expectedType QueryType
expectedWrite []string
expectedRead []string
}{
{
name: "write keyword in a string literal",
query: `MATCH (n) WHERE n.name = 'MERGE (m)' RETURN n`,
expectedType: ReadQuery,
expectedWrite: []string{},
expectedRead: []string{"MATCH", "WHERE", "RETURN"},
},
{
name: "incomplete SET clause",
query: `MATCH (n) SET`,
expectedType: WriteQuery,
expectedWrite: []string{"SET"},
expectedRead: []string{"MATCH"},
},
{
name: "keyword as a node label",
query: `MATCH (n:CREATE) RETURN n`,
expectedType: ReadQuery,
expectedWrite: []string{}, // 'CREATE' should be seen as an identifier, not a keyword
expectedRead: []string{"MATCH", "RETURN"},
},
{
name: "unbalanced parentheses",
query: `MATCH (n:Person RETURN n`,
expectedType: ReadQuery,
expectedWrite: []string{},
expectedRead: []string{"MATCH", "RETURN"},
},
{
name: "unclosed curly brace in subquery",
query: `CALL { MATCH (n) CREATE (m)`,
expectedType: WriteQuery,
expectedWrite: []string{"CREATE", "WRITE_IN_SUBQUERY"},
expectedRead: []string{"MATCH"},
},
{
name: "semicolon inside a query part",
query: `MATCH (n;Person) RETURN n`,
expectedType: ReadQuery,
expectedWrite: []string{},
expectedRead: []string{"MATCH", "RETURN"},
},
{
name: "jumbled keywords without proper syntax",
query: `RETURN CREATE MATCH DELETE`,
expectedType: WriteQuery,
// The classifier's job is to find the tokens, not validate the syntax.
// It should find both read and write tokens.
expectedWrite: []string{"CREATE", "DELETE"},
expectedRead: []string{"RETURN", "MATCH"},
},
{
name: "write in a nested subquery",
query: `
CALL {
MATCH (a)
CALL {
CREATE (b:Thing)
}
RETURN a
}
RETURN "done"
`,
expectedType: WriteQuery,
expectedWrite: []string{"CREATE", "WRITE_IN_SUBQUERY"},
expectedRead: []string{"MATCH", "RETURN", "RETURN"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// This defer-recover block ensures the test fails gracefully if the Classify function panics,
// which was the goal of the original assert.NotPanics call.
defer func() {
if r := recover(); r != nil {
t.Fatalf("The code panicked on test '%s': %v", tt.name, r)
}
}()
result := classifier.Classify(tt.query)
if tt.expectedType != result.Type {
t.Errorf("Query type mismatch: expected %v, got %v", tt.expectedType, result.Type)
}
if tt.expectedWrite != nil {
assertElementsMatch(t, tt.expectedWrite, result.WriteTokens, "Write tokens mismatch")
}
if tt.expectedRead != nil {
assertElementsMatch(t, tt.expectedRead, result.ReadTokens, "Read tokens mismatch")
}
})
}
}
func TestNormalizeQuery(t *testing.T) {
classifier := NewQueryClassifier()
t.Run("single line comment", func(t *testing.T) {
input := "MATCH (n) // comment\nRETURN n"
expected := "MATCH (n) RETURN n"
result := classifier.normalizeQuery(input)
if expected != result {
t.Errorf("normalizeQuery failed:\nexpected: %q\n got: %q", expected, result)
}
})
}

View File

@@ -0,0 +1,182 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package neo4jexecutecypher
import (
"context"
"fmt"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher/classifier"
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
)
const kind string = "neo4j-execute-cypher"
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type compatibleSource interface {
Neo4jDriver() neo4j.DriverWithContext
Neo4jDatabase() string
}
// validate compatible sources are still compatible
var _ compatibleSource = &neo4jsc.Source{}
var compatibleSources = [...]string{neo4jsc.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
ReadOnly bool `yaml:"readOnly"`
AuthRequired []string `yaml:"authRequired"`
}
// validate interface
var _ tools.ToolConfig = Config{}
func (cfg Config) ToolConfigKind() string {
return kind
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
var s compatibleSource
s, ok = rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
cypherParameter := tools.NewStringParameter("cypher", "The cypher to execute.")
parameters := tools.Parameters{cypherParameter}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: parameters.McpManifest(),
}
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
ReadOnly: cfg.ReadOnly,
Driver: s.Neo4jDriver(),
Database: s.Neo4jDatabase(),
classifier: classifier.NewQueryClassifier(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
// validate interface
var _ tools.Tool = Tool{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Parameters tools.Parameters `yaml:"parameters"`
AuthRequired []string `yaml:"authRequired"`
ReadOnly bool `yaml:"readOnly"`
Database string
Driver neo4j.DriverWithContext
classifier *classifier.QueryClassifier
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
sliceParams := params.AsSlice()
cypherStr, ok := sliceParams[0].(string)
if !ok {
return nil, fmt.Errorf("unable to get cast %s", sliceParams[0])
}
if cypherStr == "" {
return nil, fmt.Errorf("parameter 'cypher' must be a non-empty string")
}
// validate the cypher query before executing
cf := t.classifier.Classify(cypherStr)
if cf.Error != nil {
return nil, cf.Error
}
if cf.Type == classifier.WriteQuery && t.ReadOnly {
return nil, fmt.Errorf("this tool is read-only and cannot execute write queries")
}
config := neo4j.ExecuteQueryWithDatabase(t.Database)
results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, t.Driver, cypherStr, nil,
neo4j.EagerResultTransformer, config)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
var out []any
keys := results.Keys
records := results.Records
for _, record := range records {
vMap := make(map[string]any)
for col, value := range record.Values {
vMap[keys[col]] = value
}
out = append(out, vMap)
}
return out, nil
}
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.Parameters, data, claimsMap)
}
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -0,0 +1,99 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package neo4jexecutecypher
import (
"testing"
"github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
func TestParseFromYamlNeo4j(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
example_tool:
kind: neo4j-execute-cypher
source: my-neo4j-instance
description: some tool description
authRequired:
- my-google-auth-service
- other-auth-service
`,
want: server.ToolConfigs{
"example_tool": Config{
Name: "example_tool",
Kind: "neo4j-execute-cypher",
Source: "my-neo4j-instance",
Description: "some tool description",
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
},
},
},
{
desc: "readonly example",
in: `
tools:
example_tool:
kind: neo4j-execute-cypher
source: my-neo4j-instance
description: some tool description
readOnly: true
authRequired:
- my-google-auth-service
- other-auth-service
`,
want: server.ToolConfigs{
"example_tool": Config{
Name: "example_tool",
Kind: "neo4j-execute-cypher",
Source: "my-neo4j-instance",
ReadOnly: true,
Description: "some tool description",
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err = yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -23,6 +23,7 @@ import (
"os"
"reflect"
"regexp"
"strings"
"testing"
"time"
@@ -78,6 +79,17 @@ func TestNeo4jToolEndpoints(t *testing.T) {
"description": "Simple tool to test end to end functionality.",
"statement": "RETURN 1 as a;",
},
"my-simple-execute-cypher-tool": map[string]any{
"kind": "neo4j-execute-cypher",
"source": "my-neo4j-instance",
"description": "Simple tool to test end to end functionality.",
},
"my-readonly-execute-cypher-tool": map[string]any{
"kind": "neo4j-execute-cypher",
"source": "my-neo4j-instance",
"description": "A readonly cypher execution tool.",
"readOnly": true,
},
},
}
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
@@ -111,6 +123,25 @@ func TestNeo4jToolEndpoints(t *testing.T) {
},
},
},
{
name: "get my-simple-execute-cypher-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/",
want: map[string]any{
"my-simple-execute-cypher-tool": map[string]any{
"description": "Simple tool to test end to end functionality.",
"parameters": []any{
map[string]any{
"name": "cypher",
"type": "string",
"required": true,
"description": "The cypher to execute.",
"authSources": []any{},
},
},
"authRequired": []any{},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
@@ -141,16 +172,33 @@ func TestNeo4jToolEndpoints(t *testing.T) {
// Test tool invoke endpoint
invokeTcs := []struct {
name string
api string
requestBody io.Reader
want string
name string
api string
requestBody io.Reader
want string
wantStatus int
wantErrorSubstring string
}{
{
name: "invoke my-simple-cypher-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "[{\"a\":1}]",
wantStatus: http.StatusOK,
},
{
name: "invoke my-simple-execute-cypher-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{"cypher": "RETURN 1 as a;"}`)),
want: "[{\"a\":1}]",
wantStatus: http.StatusOK,
},
{
name: "invoke readonly tool with write query",
api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{"cypher": "CREATE (n:TestNode)"}`)),
wantStatus: http.StatusBadRequest,
wantErrorSubstring: "this tool is read-only and cannot execute write queries",
},
}
for _, tc := range invokeTcs {
@@ -160,23 +208,34 @@ func TestNeo4jToolEndpoints(t *testing.T) {
t.Fatalf("error when sending a request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if resp.StatusCode != tc.wantStatus {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
t.Fatalf("response status code: got %d, want %d: %s", resp.StatusCode, tc.wantStatus, string(bodyBytes))
}
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if tc.wantStatus == http.StatusOK {
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if got != tc.want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
if got != tc.want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
} else {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read error response body: %s", err)
}
bodyString := string(bodyBytes)
if !strings.Contains(bodyString, tc.wantErrorSubstring) {
t.Fatalf("response body %q does not contain expected error %q", bodyString, tc.wantErrorSubstring)
}
}
})
}