mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 16:08:16 -05:00
feat: Neo4j tools enhancements - neo4j-execute-cypher (#946)
This pull request introduces a new tool for executing arbitrary Cypher queries against a Neo4j database (`neo4j-execute-cypher`) and implements robust query classification functionality to distinguish between read and write operations. The changes include updates to documentation, the addition of a query classifier, and comprehensive test coverage for the classifier. ### Addition of `neo4j-execute-cypher` tool: - **Documentation**: Added a new markdown file `neo4j-execute-cypher.md` that explains the tool's functionality, usage, and configuration options, including the ability to enforce read-only mode for security. - **Import statement**: Registered the new tool in the `cmd/root.go` file to make it available in the toolbox. ### Query classification functionality: - **Query classifier implementation**: Added `QueryClassifier` in `classifier.go`, which classifies Cypher queries into read or write operations based on keywords, procedures, and subquery analysis. It supports handling edge cases like nested subqueries, multi-word keywords, and invalid syntax. - **Test coverage**: Created extensive tests in `classifier_test.go` to validate the classifier's behavior across various query types, including abuse cases, subqueries, and procedure calls. Tests ensure the classifier is robust and does not panic on malformed queries. --------- Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
@@ -64,6 +64,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jcypher"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/redis"
|
||||
|
||||
53
docs/en/resources/tools/neo4j/neo4j-execute-cypher.md
Normal file
53
docs/en/resources/tools/neo4j/neo4j-execute-cypher.md
Normal file
@@ -0,0 +1,53 @@
|
||||
---
|
||||
title: "neo4j-execute-cypher"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "neo4j-execute-cypher" tool executes any arbitrary Cypher statement against a Neo4j
|
||||
database.
|
||||
aliases:
|
||||
- /resources/tools/neo4j-execute-cypher
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `neo4j-execute-cypher` tool executes an arbitrary Cypher query provided as a string parameter against a Neo4j database. It's designed to be a flexible tool for interacting with the database when a pre-defined query is not sufficient. This tool is compatible with any of the following sources:
|
||||
|
||||
- [neo4j](../sources/neo4j.md)
|
||||
|
||||
For security, the tool can be configured to be read-only. If the `readOnly` flag is set to `true`, the tool will analyze the incoming Cypher query and reject any write operations (like `CREATE`, `MERGE`, `DELETE`, etc.) before execution.
|
||||
|
||||
The Cypher query uses standard [Neo4j Cypher](https://neo4j.com/docs/cypher-manual/current/queries/) syntax and supports all Cypher features, including pattern matching, filtering, and aggregation.
|
||||
|
||||
`neo4j-execute-cypher` takes one input parameter, `cypher`, and runs that Cypher query against the `source`.
|
||||
|
||||
> **Note:** This tool is intended for developer assistant workflows with
|
||||
> human-in-the-loop and shouldn't be used for production agents.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
query_neo4j:
|
||||
kind: neo4j-execute-cypher
|
||||
source: my-neo4j-prod-db
|
||||
readOnly: true
|
||||
description: |
|
||||
Use this tool to execute a Cypher query against the production database.
|
||||
Only read-only queries are allowed.
|
||||
Takes a single 'cypher' parameter containing the full query string.
|
||||
Example:
|
||||
{{
|
||||
"cypher": "MATCH (m:Movie {title: 'The Matrix'}) RETURN m.released"
|
||||
}}
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|-------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "neo4j-execute-cypher". |
|
||||
| source | string | true | Name of the source the Cypher query should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
| readOnly | boolean | false | If set to `true`, the tool will reject any write operations in the Cypher query. Default is `false`. |
|
||||
|
||||
434
internal/tools/neo4j/neo4jexecutecypher/classifier/classifier.go
Normal file
434
internal/tools/neo4j/neo4jexecutecypher/classifier/classifier.go
Normal file
@@ -0,0 +1,434 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
Package classifier provides tools to classify Cypher queries as either read-only or write operations.
|
||||
|
||||
It uses a keyword-based and procedure-based approach to determine the query's nature.
|
||||
The main entry point is the `Classify` method on a `QueryClassifier` object. The classifier
|
||||
is designed to be conservative, defaulting to classifying unknown procedures as write
|
||||
operations to ensure safety in read-only environments.
|
||||
|
||||
It can handle:
|
||||
- Standard Cypher keywords (MATCH, CREATE, MERGE, etc.).
|
||||
- Multi-word keywords (DETACH DELETE, ORDER BY).
|
||||
- Comments and string literals, which are ignored during classification.
|
||||
- Procedure calls (CALL db.labels), with predefined lists of known read/write procedures.
|
||||
- Subqueries (CALL { ... }), checking for write operations within the subquery block.
|
||||
*/
|
||||
package classifier
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// QueryType tells whether a classified Cypher query reads or modifies data.
type QueryType int

const (
	// ReadQuery marks a query that only reads data. It is the zero value.
	ReadQuery QueryType = iota
	// WriteQuery marks a query that modifies data.
	WriteQuery
)

// String returns "READ" for ReadQuery and "WRITE" for every other value.
func (qt QueryType) String() string {
	switch qt {
	case ReadQuery:
		return "READ"
	default:
		return "WRITE"
	}
}
|
||||
|
||||
// QueryClassification represents the detailed result of a query classification.
|
||||
type QueryClassification struct {
|
||||
// Type is the overall classification of the query (READ or WRITE).
|
||||
Type QueryType
|
||||
// Confidence is a score from 0.0 to 1.0 indicating the classifier's certainty.
|
||||
// 1.0 is fully confident. Lower scores may be assigned for ambiguous cases,
|
||||
// like unknown procedures.
|
||||
Confidence float64
|
||||
// WriteTokens is a list of keywords or procedures found that indicate a write operation.
|
||||
WriteTokens []string
|
||||
// ReadTokens is a list of keywords or procedures found that indicate a read operation.
|
||||
ReadTokens []string
|
||||
// HasSubquery is true if the query contains a `CALL { ... }` block.
|
||||
HasSubquery bool
|
||||
// Error holds any error that occurred during classification, though this is not
|
||||
// currently used in the implementation.
|
||||
Error error
|
||||
}
|
||||
|
||||
// QueryClassifier holds the keyword tables, procedure prefixes and compiled
// regular expressions used to classify Cypher queries.
// Create instances with NewQueryClassifier.
type QueryClassifier struct {
	// writeKeywords and readKeywords are keyword sets; multi-word keywords
	// are stored with underscores in place of spaces.
	writeKeywords map[string]struct{}
	readKeywords  map[string]struct{}
	// writeProcedures is the set of lower-cased procedure prefixes known to write.
	writeProcedures map[string]struct{}
	// readProcedures is the set of lower-cased procedure prefixes known to read.
	readProcedures map[string]struct{}
	// multiWordWriteKeywords and multiWordReadKeywords keep the original
	// space-separated spelling of the multi-word keywords.
	multiWordWriteKeywords []string
	multiWordReadKeywords  []string
	// Pre-compiled patterns used while normalizing and tokenizing a query.
	commentPattern       *regexp.Regexp
	stringLiteralPattern *regexp.Regexp
	procedureCallPattern *regexp.Regexp
	subqueryPattern      *regexp.Regexp
	whitespacePattern    *regexp.Regexp
	tokenSplitPattern    *regexp.Regexp
}
|
||||
|
||||
// NewQueryClassifier creates and initializes a new QueryClassifier instance.
|
||||
// It pre-compiles regular expressions and populates the internal lists of
|
||||
// known Cypher keywords and procedures.
|
||||
func NewQueryClassifier() *QueryClassifier {
|
||||
c := &QueryClassifier{
|
||||
writeKeywords: make(map[string]struct{}),
|
||||
readKeywords: make(map[string]struct{}),
|
||||
writeProcedures: make(map[string]struct{}),
|
||||
readProcedures: make(map[string]struct{}),
|
||||
commentPattern: regexp.MustCompile(`(?m)//.*?$|/\*[\s\S]*?\*/`),
|
||||
stringLiteralPattern: regexp.MustCompile(`'[^']*'|"[^"]*"`),
|
||||
procedureCallPattern: regexp.MustCompile(`(?i)\bCALL\s+([a-zA-Z0-9_.]+)`),
|
||||
subqueryPattern: regexp.MustCompile(`(?i)\bCALL\s*\{`),
|
||||
whitespacePattern: regexp.MustCompile(`\s+`),
|
||||
tokenSplitPattern: regexp.MustCompile(`[\s,(){}[\]]+`),
|
||||
}
|
||||
|
||||
// Lists of known keywords that perform write operations.
|
||||
writeKeywordsList := []string{
|
||||
"CREATE", "MERGE", "DELETE", "DETACH DELETE", "SET", "REMOVE", "FOREACH",
|
||||
"CREATE INDEX", "DROP INDEX", "CREATE CONSTRAINT", "DROP CONSTRAINT",
|
||||
}
|
||||
// Lists of known keywords that perform read operations.
|
||||
readKeywordsList := []string{
|
||||
"MATCH", "OPTIONAL MATCH", "WITH", "WHERE", "RETURN", "ORDER BY", "SKIP", "LIMIT",
|
||||
"UNION", "UNION ALL", "UNWIND", "CASE", "WHEN", "THEN", "ELSE", "END",
|
||||
"SHOW", "PROFILE", "EXPLAIN",
|
||||
}
|
||||
// A list of procedure prefixes known to perform write operations.
|
||||
writeProceduresList := []string{
|
||||
"apoc.create", "apoc.merge", "apoc.refactor", "apoc.atomic", "apoc.trigger",
|
||||
"apoc.periodic.commit", "apoc.load.jdbc", "apoc.load.json", "apoc.load.csv",
|
||||
"apoc.export", "apoc.import", "db.create", "db.drop", "db.index.create",
|
||||
"db.constraints.create", "dbms.security.create", "gds.graph.create", "gds.graph.drop",
|
||||
}
|
||||
// A list of procedure prefixes known to perform read operations.
|
||||
readProceduresList := []string{
|
||||
"apoc.meta", "apoc.help", "apoc.version", "apoc.text", "apoc.math", "apoc.coll",
|
||||
"apoc.path", "apoc.algo", "apoc.date", "db.labels", "db.propertyKeys",
|
||||
"db.relationshipTypes", "db.schema", "db.indexes", "db.constraints",
|
||||
"dbms.components", "dbms.listConfig", "gds.graph.list", "gds.util",
|
||||
}
|
||||
|
||||
c.populateKeywords(writeKeywordsList, c.writeKeywords, &c.multiWordWriteKeywords)
|
||||
c.populateKeywords(readKeywordsList, c.readKeywords, &c.multiWordReadKeywords)
|
||||
c.populateProcedures(writeProceduresList, c.writeProcedures)
|
||||
c.populateProcedures(readProceduresList, c.readProcedures)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// populateKeywords processes a list of keyword strings, separating them into
|
||||
// single-word and multi-word lists for easier processing later.
|
||||
// Multi-word keywords (e.g., "DETACH DELETE") are sorted by length descending
|
||||
// to ensure longer matches are replaced first.
|
||||
func (c *QueryClassifier) populateKeywords(keywords []string, keywordMap map[string]struct{}, multiWord *[]string) {
|
||||
for _, kw := range keywords {
|
||||
if strings.Contains(kw, " ") {
|
||||
*multiWord = append(*multiWord, kw)
|
||||
}
|
||||
// Replace spaces with underscores for unified tokenization.
|
||||
keywordMap[strings.ReplaceAll(kw, " ", "_")] = struct{}{}
|
||||
}
|
||||
// Sort multi-word keywords by length (longest first) to prevent
|
||||
// partial matches, e.g., replacing "CREATE OR REPLACE" before "CREATE".
|
||||
sort.SliceStable(*multiWord, func(i, j int) bool {
|
||||
return len((*multiWord)[i]) > len((*multiWord)[j])
|
||||
})
|
||||
}
|
||||
|
||||
// populateProcedures adds a list of procedure prefixes to the given map.
|
||||
func (c *QueryClassifier) populateProcedures(procedures []string, procedureMap map[string]struct{}) {
|
||||
for _, proc := range procedures {
|
||||
procedureMap[strings.ToLower(proc)] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Classify analyzes a Cypher query string and returns a QueryClassification result.
|
||||
// It is the main method for this package.
|
||||
//
|
||||
// The process is as follows:
|
||||
// 1. Normalize the query by removing comments and extra whitespace.
|
||||
// 2. Replace string literals to prevent keywords inside them from being classified.
|
||||
// 3. Unify multi-word keywords (e.g., "DETACH DELETE" becomes "DETACH_DELETE").
|
||||
// 4. Extract all procedure calls (e.g., `CALL db.labels`).
|
||||
// 5. Tokenize the remaining query string.
|
||||
// 6. Check tokens and procedures against known read/write lists.
|
||||
// 7. If a subquery `CALL { ... }` exists, check its contents for write operations.
|
||||
// 8. Assign a final classification and confidence score.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// classifier := NewQueryClassifier()
|
||||
// query := "MATCH (n:Person) WHERE n.name = 'Alice' SET n.age = 30"
|
||||
// result := classifier.Classify(query)
|
||||
// fmt.Printf("Query is a %s query with confidence %f\n", result.Type, result.Confidence)
|
||||
// // Output: Query is a WRITE query with confidence 0.900000
|
||||
// fmt.Printf("Write tokens found: %v\n", result.WriteTokens)
|
||||
// // Output: Write tokens found: [SET]
|
||||
func (c *QueryClassifier) Classify(query string) QueryClassification {
|
||||
result := QueryClassification{
|
||||
Type: ReadQuery, // Default to read, upgrade to write if write tokens are found.
|
||||
Confidence: 1.0,
|
||||
}
|
||||
|
||||
normalizedQuery := c.normalizeQuery(query)
|
||||
if normalizedQuery == "" {
|
||||
return result // Return default for empty queries.
|
||||
}
|
||||
|
||||
// Early check for subqueries to set the flag.
|
||||
result.HasSubquery = c.subqueryPattern.MatchString(normalizedQuery)
|
||||
procedures := c.extractProcedureCalls(normalizedQuery)
|
||||
|
||||
// Sanitize the query by replacing string literals to avoid misinterpreting their contents.
|
||||
sanitizedQuery := c.stringLiteralPattern.ReplaceAllString(normalizedQuery, "STRING_LITERAL")
|
||||
// Unify multi-word keywords to treat them as single tokens.
|
||||
unifiedQuery := c.unifyMultiWordKeywords(sanitizedQuery)
|
||||
tokens := c.extractTokens(unifiedQuery)
|
||||
|
||||
// Classify based on standard keywords.
|
||||
for _, token := range tokens {
|
||||
upperToken := strings.ToUpper(token)
|
||||
|
||||
if _, isWrite := c.writeKeywords[upperToken]; isWrite {
|
||||
result.WriteTokens = append(result.WriteTokens, upperToken)
|
||||
result.Type = WriteQuery
|
||||
} else if _, isRead := c.readKeywords[upperToken]; isRead {
|
||||
result.ReadTokens = append(result.ReadTokens, upperToken)
|
||||
}
|
||||
}
|
||||
|
||||
// Classify based on procedure calls.
|
||||
for _, proc := range procedures {
|
||||
if c.isWriteProcedure(proc) {
|
||||
result.WriteTokens = append(result.WriteTokens, "CALL "+proc)
|
||||
result.Type = WriteQuery
|
||||
} else if c.isReadProcedure(proc) {
|
||||
result.ReadTokens = append(result.ReadTokens, "CALL "+proc)
|
||||
} else {
|
||||
// CONSERVATIVE APPROACH: If a procedure is not in a known list,
|
||||
// we guess its type. If it looks like a read (get, list), we treat it as such.
|
||||
// Otherwise, we assume it's a write operation with lower confidence.
|
||||
if strings.Contains(proc, ".get") || strings.Contains(proc, ".list") ||
|
||||
strings.Contains(proc, ".show") || strings.Contains(proc, ".meta") {
|
||||
result.ReadTokens = append(result.ReadTokens, "CALL "+proc)
|
||||
} else {
|
||||
result.WriteTokens = append(result.WriteTokens, "CALL "+proc)
|
||||
result.Type = WriteQuery
|
||||
result.Confidence = 0.8 // Lower confidence for unknown procedures.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a subquery exists, explicitly check its contents for write operations.
|
||||
if result.HasSubquery && c.hasWriteInSubquery(unifiedQuery) {
|
||||
result.Type = WriteQuery
|
||||
// Add a specific token to indicate the reason for the write classification.
|
||||
found := false
|
||||
for _, t := range result.WriteTokens {
|
||||
if t == "WRITE_IN_SUBQUERY" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
result.WriteTokens = append(result.WriteTokens, "WRITE_IN_SUBQUERY")
|
||||
}
|
||||
}
|
||||
|
||||
// If a query contains both read and write operations (e.g., MATCH ... DELETE),
|
||||
// it's a write query. We lower the confidence slightly to reflect the mixed nature.
|
||||
if len(result.WriteTokens) > 0 && len(result.ReadTokens) > 0 {
|
||||
result.Confidence = 0.9
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// unifyMultiWordKeywords replaces multi-word keywords in a query with a single,
|
||||
// underscore-separated token. This simplifies the tokenization process.
|
||||
// Example: "DETACH DELETE" becomes "DETACH_DELETE".
|
||||
func (c *QueryClassifier) unifyMultiWordKeywords(query string) string {
|
||||
upperQuery := strings.ToUpper(query)
|
||||
// Combine all multi-word keywords for a single pass.
|
||||
allMultiWord := append(c.multiWordWriteKeywords, c.multiWordReadKeywords...)
|
||||
|
||||
for _, kw := range allMultiWord {
|
||||
placeholder := strings.ReplaceAll(kw, " ", "_")
|
||||
upperQuery = strings.ReplaceAll(upperQuery, kw, placeholder)
|
||||
}
|
||||
return upperQuery
|
||||
}
|
||||
|
||||
// normalizeQuery cleans a query string by removing comments and collapsing
|
||||
// all whitespace into single spaces.
|
||||
func (c *QueryClassifier) normalizeQuery(query string) string {
|
||||
// Remove single-line and multi-line comments.
|
||||
query = c.commentPattern.ReplaceAllString(query, " ")
|
||||
// Collapse consecutive whitespace characters into a single space.
|
||||
query = c.whitespacePattern.ReplaceAllString(query, " ")
|
||||
return strings.TrimSpace(query)
|
||||
}
|
||||
|
||||
// extractTokens splits a query string into a slice of individual tokens.
|
||||
// It splits on whitespace and various punctuation marks.
|
||||
func (c *QueryClassifier) extractTokens(query string) []string {
|
||||
tokens := c.tokenSplitPattern.Split(query, -1)
|
||||
// Filter out empty strings that can result from the split.
|
||||
result := make([]string, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
if token != "" {
|
||||
result = append(result, token)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// extractProcedureCalls finds all procedure calls (e.g., `CALL db.labels`)
|
||||
// in the query and returns a slice of their names.
|
||||
func (c *QueryClassifier) extractProcedureCalls(query string) []string {
|
||||
matches := c.procedureCallPattern.FindAllStringSubmatch(query, -1)
|
||||
procedures := make([]string, 0, len(matches))
|
||||
for _, match := range matches {
|
||||
if len(match) > 1 {
|
||||
procedures = append(procedures, strings.ToLower(match[1]))
|
||||
}
|
||||
}
|
||||
return procedures
|
||||
}
|
||||
|
||||
// isWriteProcedure checks if a given procedure name matches any of the known
|
||||
// write procedure prefixes.
|
||||
func (c *QueryClassifier) isWriteProcedure(procedure string) bool {
|
||||
procedure = strings.ToLower(procedure)
|
||||
for wp := range c.writeProcedures {
|
||||
if strings.HasPrefix(procedure, wp) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isReadProcedure checks if a given procedure name matches any of the known
|
||||
// read procedure prefixes.
|
||||
func (c *QueryClassifier) isReadProcedure(procedure string) bool {
|
||||
procedure = strings.ToLower(procedure)
|
||||
for rp := range c.readProcedures {
|
||||
if strings.HasPrefix(procedure, rp) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasWriteInSubquery detects if a write keyword exists within a `CALL { ... }` block.
|
||||
// It correctly handles nested braces to find the content of the top-level subquery.
|
||||
func (c *QueryClassifier) hasWriteInSubquery(unifiedQuery string) bool {
|
||||
loc := c.subqueryPattern.FindStringIndex(unifiedQuery)
|
||||
if loc == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// The search starts from the beginning of the `CALL {` match.
|
||||
subqueryContent := unifiedQuery[loc[0]:]
|
||||
openBraces := 0
|
||||
startIndex := -1
|
||||
endIndex := -1
|
||||
|
||||
// Find the boundaries of the first complete `{...}` block.
|
||||
for i, char := range subqueryContent {
|
||||
if char == '{' {
|
||||
if openBraces == 0 {
|
||||
startIndex = i + 1
|
||||
}
|
||||
openBraces++
|
||||
} else if char == '}' {
|
||||
openBraces--
|
||||
if openBraces == 0 {
|
||||
endIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var block string
|
||||
if startIndex != -1 {
|
||||
if endIndex != -1 {
|
||||
// A complete `{...}` block was found.
|
||||
block = subqueryContent[startIndex:endIndex]
|
||||
} else {
|
||||
// An opening brace was found but no closing one; this indicates a
|
||||
// likely syntax error, but we check the rest of the string anyway.
|
||||
block = subqueryContent[startIndex:]
|
||||
}
|
||||
|
||||
// Check if any write keyword exists as a whole word within the subquery block.
|
||||
for writeOp := range c.writeKeywords {
|
||||
// Use regex to match the keyword as a whole word to avoid partial matches
|
||||
// (e.g., finding "SET" in "ASSET").
|
||||
re := regexp.MustCompile(`\b` + writeOp + `\b`)
|
||||
if re.MatchString(block) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddWriteProcedure allows users to dynamically add a custom procedure prefix to the
|
||||
// list of known write procedures. This is useful for environments with custom plugins.
|
||||
// The pattern is matched using `strings.HasPrefix`.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// classifier := NewQueryClassifier()
|
||||
// classifier.AddWriteProcedure("my.custom.writer")
|
||||
// result := classifier.Classify("CALL my.custom.writer.createUser()")
|
||||
// // result.Type will be WriteQuery
|
||||
func (c *QueryClassifier) AddWriteProcedure(pattern string) {
|
||||
if pattern != "" {
|
||||
c.writeProcedures[strings.ToLower(pattern)] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// AddReadProcedure allows users to dynamically add a custom procedure prefix to the
|
||||
// list of known read procedures.
|
||||
// The pattern is matched using `strings.HasPrefix`.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// classifier := NewQueryClassifier()
|
||||
// classifier.AddReadProcedure("my.custom.reader")
|
||||
// result := classifier.Classify("CALL my.custom.reader.getData()")
|
||||
// // result.Type will be ReadQuery
|
||||
func (c *QueryClassifier) AddReadProcedure(pattern string) {
|
||||
if pattern != "" {
|
||||
c.readProcedures[strings.ToLower(pattern)] = struct{}{}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,367 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package classifier
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// assertElementsMatch fails the test, without stopping it, when expected and
// actual do not hold the same strings with the same multiplicities. Order is
// ignored, and a nil slice is treated the same as an empty one. It stands in
// for testify's assert.ElementsMatch.
func assertElementsMatch(t *testing.T, expected, actual []string, msg string) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	bothEmpty := len(expected) == 0 && len(actual) == 0
	if bothEmpty {
		return
	}

	// Sort private copies so the caller's slices keep their order.
	sortedExpected := make([]string, len(expected))
	copy(sortedExpected, expected)
	sort.Sort(sort.StringSlice(sortedExpected))

	sortedActual := make([]string, len(actual))
	copy(sortedActual, actual)
	sort.Sort(sort.StringSlice(sortedActual))

	if reflect.DeepEqual(sortedExpected, sortedActual) {
		return
	}
	t.Errorf("%s: \nexpected: %v\n got: %v", msg, expected, actual)
}
|
||||
|
||||
func TestQueryClassifier_Classify(t *testing.T) {
|
||||
classifier := NewQueryClassifier()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
expectedType QueryType
|
||||
expectedWrite []string
|
||||
expectedRead []string
|
||||
minConfidence float64
|
||||
}{
|
||||
// Read queries
|
||||
{
|
||||
name: "simple MATCH query",
|
||||
query: "MATCH (n:Person) RETURN n",
|
||||
expectedType: ReadQuery,
|
||||
expectedRead: []string{"MATCH", "RETURN"},
|
||||
expectedWrite: []string{},
|
||||
minConfidence: 1.0,
|
||||
},
|
||||
{
|
||||
name: "complex read query",
|
||||
query: "MATCH (p:Person)-[:KNOWS]->(f) WHERE p.age > 30 RETURN p.name, count(f) ORDER BY p.name SKIP 10 LIMIT 5",
|
||||
expectedType: ReadQuery,
|
||||
expectedRead: []string{"MATCH", "WHERE", "RETURN", "ORDER_BY", "SKIP", "LIMIT"},
|
||||
expectedWrite: []string{},
|
||||
minConfidence: 1.0,
|
||||
},
|
||||
{
|
||||
name: "UNION query",
|
||||
query: "MATCH (n:Person) RETURN n.name UNION MATCH (m:Company) RETURN m.name",
|
||||
expectedType: ReadQuery,
|
||||
expectedRead: []string{"MATCH", "RETURN", "UNION", "MATCH", "RETURN"},
|
||||
expectedWrite: []string{},
|
||||
minConfidence: 1.0,
|
||||
},
|
||||
|
||||
// Write queries
|
||||
{
|
||||
name: "CREATE query",
|
||||
query: "CREATE (n:Person {name: 'John', age: 30})",
|
||||
expectedType: WriteQuery,
|
||||
expectedWrite: []string{"CREATE"},
|
||||
expectedRead: []string{},
|
||||
minConfidence: 1.0,
|
||||
},
|
||||
{
|
||||
name: "MERGE query",
|
||||
query: "MERGE (n:Person {id: 123}) ON CREATE SET n.created = timestamp()",
|
||||
expectedType: WriteQuery,
|
||||
expectedWrite: []string{"MERGE", "CREATE", "SET"},
|
||||
expectedRead: []string{},
|
||||
minConfidence: 1.0,
|
||||
},
|
||||
{
|
||||
name: "DETACH DELETE query",
|
||||
query: "MATCH (n:Person) DETACH DELETE n",
|
||||
expectedType: WriteQuery,
|
||||
expectedWrite: []string{"DETACH_DELETE"},
|
||||
expectedRead: []string{"MATCH"},
|
||||
minConfidence: 0.9,
|
||||
},
|
||||
|
||||
// Procedure calls
|
||||
{
|
||||
name: "read procedure",
|
||||
query: "CALL db.labels() YIELD label RETURN label",
|
||||
expectedType: ReadQuery,
|
||||
expectedRead: []string{"RETURN", "CALL db.labels"},
|
||||
expectedWrite: []string{},
|
||||
minConfidence: 1.0,
|
||||
},
|
||||
{
|
||||
name: "unknown procedure conservative",
|
||||
query: "CALL custom.procedure.doSomething()",
|
||||
expectedType: WriteQuery,
|
||||
expectedWrite: []string{"CALL custom.procedure.dosomething"},
|
||||
expectedRead: []string{},
|
||||
minConfidence: 0.8,
|
||||
},
|
||||
{
|
||||
name: "unknown read-like procedure",
|
||||
query: "CALL custom.procedure.getUsers()",
|
||||
expectedType: ReadQuery,
|
||||
expectedRead: []string{"CALL custom.procedure.getusers"},
|
||||
expectedWrite: []string{},
|
||||
minConfidence: 1.0,
|
||||
},
|
||||
|
||||
// Subqueries
|
||||
{
|
||||
name: "read subquery",
|
||||
query: "CALL { MATCH (n:Person) RETURN n } RETURN n",
|
||||
expectedType: ReadQuery,
|
||||
expectedRead: []string{"MATCH", "RETURN", "RETURN"},
|
||||
expectedWrite: []string{},
|
||||
minConfidence: 1.0,
|
||||
},
|
||||
{
|
||||
name: "write subquery",
|
||||
query: "CALL { CREATE (n:Person) RETURN n } RETURN n",
|
||||
expectedType: WriteQuery,
|
||||
expectedWrite: []string{"CREATE", "WRITE_IN_SUBQUERY"},
|
||||
expectedRead: []string{"RETURN", "RETURN"},
|
||||
minConfidence: 0.9,
|
||||
},
|
||||
|
||||
// Multiline Queries
|
||||
{
|
||||
name: "multiline read query with comments",
|
||||
query: `
|
||||
// Find all people and their friends
|
||||
MATCH (p:Person)-[:KNOWS]->(f:Friend)
|
||||
/*
|
||||
Where the person is older than 25
|
||||
*/
|
||||
WHERE p.age > 25
|
||||
RETURN p.name, f.name
|
||||
`,
|
||||
expectedType: ReadQuery,
|
||||
expectedWrite: []string{},
|
||||
expectedRead: []string{"MATCH", "WHERE", "RETURN"},
|
||||
minConfidence: 1.0,
|
||||
},
|
||||
{
|
||||
name: "multiline write query",
|
||||
query: `
|
||||
MATCH (p:Person {name: 'Alice'})
|
||||
CREATE (c:Company {name: 'Neo4j'})
|
||||
CREATE (p)-[:WORKS_FOR]->(c)
|
||||
`,
|
||||
expectedType: WriteQuery,
|
||||
expectedWrite: []string{"CREATE", "CREATE"},
|
||||
expectedRead: []string{"MATCH"},
|
||||
minConfidence: 0.9,
|
||||
},
|
||||
|
||||
// Complex Subqueries
|
||||
{
|
||||
name: "nested read subquery",
|
||||
query: `
|
||||
CALL {
|
||||
MATCH (p:Person)
|
||||
RETURN p
|
||||
}
|
||||
CALL {
|
||||
MATCH (c:Company)
|
||||
RETURN c
|
||||
}
|
||||
RETURN p, c
|
||||
`,
|
||||
expectedType: ReadQuery,
|
||||
expectedWrite: []string{},
|
||||
expectedRead: []string{"MATCH", "RETURN", "MATCH", "RETURN", "RETURN"},
|
||||
minConfidence: 1.0,
|
||||
},
|
||||
{
|
||||
name: "subquery with write and outer read",
|
||||
query: `
|
||||
MATCH (u:User {id: 1})
|
||||
CALL {
|
||||
WITH u
|
||||
CREATE (p:Post {content: 'New post'})
|
||||
CREATE (u)-[:AUTHORED]->(p)
|
||||
RETURN p
|
||||
}
|
||||
RETURN u.name, p.content
|
||||
`,
|
||||
expectedType: WriteQuery,
|
||||
expectedWrite: []string{"CREATE", "CREATE", "WRITE_IN_SUBQUERY"},
|
||||
expectedRead: []string{"MATCH", "WITH", "RETURN", "RETURN"},
|
||||
minConfidence: 0.9,
|
||||
},
|
||||
{
|
||||
name: "subquery with read passing to outer write",
|
||||
query: `
|
||||
CALL {
|
||||
MATCH (p:Product {id: 'abc'})
|
||||
RETURN p
|
||||
}
|
||||
WITH p
|
||||
SET p.lastViewed = timestamp()
|
||||
`,
|
||||
expectedType: WriteQuery,
|
||||
expectedWrite: []string{"SET"},
|
||||
expectedRead: []string{"MATCH", "RETURN", "WITH"},
|
||||
minConfidence: 0.9,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := classifier.Classify(tt.query)
|
||||
|
||||
if tt.expectedType != result.Type {
|
||||
t.Errorf("Query type mismatch: expected %v, got %v", tt.expectedType, result.Type)
|
||||
}
|
||||
if result.Confidence < tt.minConfidence {
|
||||
t.Errorf("Confidence too low: expected at least %f, got %f", tt.minConfidence, result.Confidence)
|
||||
}
|
||||
assertElementsMatch(t, tt.expectedWrite, result.WriteTokens, "Write tokens mismatch")
|
||||
assertElementsMatch(t, tt.expectedRead, result.ReadTokens, "Read tokens mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryClassifier_AbuseCases(t *testing.T) {
|
||||
classifier := NewQueryClassifier()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
expectedType QueryType
|
||||
expectedWrite []string
|
||||
expectedRead []string
|
||||
}{
|
||||
{
|
||||
name: "write keyword in a string literal",
|
||||
query: `MATCH (n) WHERE n.name = 'MERGE (m)' RETURN n`,
|
||||
expectedType: ReadQuery,
|
||||
expectedWrite: []string{},
|
||||
expectedRead: []string{"MATCH", "WHERE", "RETURN"},
|
||||
},
|
||||
{
|
||||
name: "incomplete SET clause",
|
||||
query: `MATCH (n) SET`,
|
||||
expectedType: WriteQuery,
|
||||
expectedWrite: []string{"SET"},
|
||||
expectedRead: []string{"MATCH"},
|
||||
},
|
||||
{
|
||||
name: "keyword as a node label",
|
||||
query: `MATCH (n:CREATE) RETURN n`,
|
||||
expectedType: ReadQuery,
|
||||
expectedWrite: []string{}, // 'CREATE' should be seen as an identifier, not a keyword
|
||||
expectedRead: []string{"MATCH", "RETURN"},
|
||||
},
|
||||
{
|
||||
name: "unbalanced parentheses",
|
||||
query: `MATCH (n:Person RETURN n`,
|
||||
expectedType: ReadQuery,
|
||||
expectedWrite: []string{},
|
||||
expectedRead: []string{"MATCH", "RETURN"},
|
||||
},
|
||||
{
|
||||
name: "unclosed curly brace in subquery",
|
||||
query: `CALL { MATCH (n) CREATE (m)`,
|
||||
expectedType: WriteQuery,
|
||||
expectedWrite: []string{"CREATE", "WRITE_IN_SUBQUERY"},
|
||||
expectedRead: []string{"MATCH"},
|
||||
},
|
||||
{
|
||||
name: "semicolon inside a query part",
|
||||
query: `MATCH (n;Person) RETURN n`,
|
||||
expectedType: ReadQuery,
|
||||
expectedWrite: []string{},
|
||||
expectedRead: []string{"MATCH", "RETURN"},
|
||||
},
|
||||
{
|
||||
name: "jumbled keywords without proper syntax",
|
||||
query: `RETURN CREATE MATCH DELETE`,
|
||||
expectedType: WriteQuery,
|
||||
// The classifier's job is to find the tokens, not validate the syntax.
|
||||
// It should find both read and write tokens.
|
||||
expectedWrite: []string{"CREATE", "DELETE"},
|
||||
expectedRead: []string{"RETURN", "MATCH"},
|
||||
},
|
||||
{
|
||||
name: "write in a nested subquery",
|
||||
query: `
|
||||
CALL {
|
||||
MATCH (a)
|
||||
CALL {
|
||||
CREATE (b:Thing)
|
||||
}
|
||||
RETURN a
|
||||
}
|
||||
RETURN "done"
|
||||
`,
|
||||
expectedType: WriteQuery,
|
||||
expectedWrite: []string{"CREATE", "WRITE_IN_SUBQUERY"},
|
||||
expectedRead: []string{"MATCH", "RETURN", "RETURN"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// This defer-recover block ensures the test fails gracefully if the Classify function panics,
|
||||
// which was the goal of the original assert.NotPanics call.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("The code panicked on test '%s': %v", tt.name, r)
|
||||
}
|
||||
}()
|
||||
|
||||
result := classifier.Classify(tt.query)
|
||||
if tt.expectedType != result.Type {
|
||||
t.Errorf("Query type mismatch: expected %v, got %v", tt.expectedType, result.Type)
|
||||
}
|
||||
if tt.expectedWrite != nil {
|
||||
assertElementsMatch(t, tt.expectedWrite, result.WriteTokens, "Write tokens mismatch")
|
||||
}
|
||||
if tt.expectedRead != nil {
|
||||
assertElementsMatch(t, tt.expectedRead, result.ReadTokens, "Read tokens mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQuery(t *testing.T) {
|
||||
classifier := NewQueryClassifier()
|
||||
t.Run("single line comment", func(t *testing.T) {
|
||||
input := "MATCH (n) // comment\nRETURN n"
|
||||
expected := "MATCH (n) RETURN n"
|
||||
result := classifier.normalizeQuery(input)
|
||||
if expected != result {
|
||||
t.Errorf("normalizeQuery failed:\nexpected: %q\n got: %q", expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
182
internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go
Normal file
182
internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go
Normal file
@@ -0,0 +1,182 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package neo4jexecutecypher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher/classifier"
|
||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
|
||||
)
|
||||
|
||||
// kind is the identifier this tool is registered under in init and the value
// reported by Config.ToolConfigKind.
const kind string = "neo4j-execute-cypher"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Neo4jDriver() neo4j.DriverWithContext
|
||||
Neo4jDatabase() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &neo4jsc.Source{}
|
||||
|
||||
// compatibleSources lists the source kinds accepted by this tool; Initialize
// includes it in the error returned for an incompatible source.
var compatibleSources = [...]string{neo4jsc.SourceKind}
|
||||
|
||||
// Config is the YAML-decoded definition of a neo4j-execute-cypher tool.
//
// Name, Kind, Source and Description are tagged as required. Source names the
// entry in the sources map that Initialize looks up. ReadOnly is copied onto
// the Tool, where it causes write queries to be rejected. AuthRequired is
// passed through to the Tool and its manifest.
type Config struct {
	Name         string   `yaml:"name" validate:"required"`
	Kind         string   `yaml:"kind" validate:"required"`
	Source       string   `yaml:"source" validate:"required"`
	Description  string   `yaml:"description" validate:"required"`
	ReadOnly     bool     `yaml:"readOnly"`
	AuthRequired []string `yaml:"authRequired"`
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
var s compatibleSource
|
||||
s, ok = rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
cypherParameter := tools.NewStringParameter("cypher", "The cypher to execute.")
|
||||
parameters := tools.Parameters{cypherParameter}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
ReadOnly: cfg.ReadOnly,
|
||||
Driver: s.Neo4jDriver(),
|
||||
Database: s.Neo4jDatabase(),
|
||||
classifier: classifier.NewQueryClassifier(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
ReadOnly bool `yaml:"readOnly"`
|
||||
Database string
|
||||
Driver neo4j.DriverWithContext
|
||||
classifier *classifier.QueryClassifier
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
cypherStr, ok := sliceParams[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to get cast %s", sliceParams[0])
|
||||
}
|
||||
|
||||
if cypherStr == "" {
|
||||
return nil, fmt.Errorf("parameter 'cypher' must be a non-empty string")
|
||||
}
|
||||
|
||||
// validate the cypher query before executing
|
||||
cf := t.classifier.Classify(cypherStr)
|
||||
if cf.Error != nil {
|
||||
return nil, cf.Error
|
||||
}
|
||||
|
||||
if cf.Type == classifier.WriteQuery && t.ReadOnly {
|
||||
return nil, fmt.Errorf("this tool is read-only and cannot execute write queries")
|
||||
}
|
||||
|
||||
config := neo4j.ExecuteQueryWithDatabase(t.Database)
|
||||
results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, t.Driver, cypherStr, nil,
|
||||
neo4j.EagerResultTransformer, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
keys := results.Keys
|
||||
records := results.Records
|
||||
for _, record := range records {
|
||||
vMap := make(map[string]any)
|
||||
for col, value := range record.Values {
|
||||
vMap[keys[col]] = value
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claimsMap)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package neo4jexecutecypher
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
// TestParseFromYamlNeo4j unmarshals YAML tool definitions of kind
// neo4j-execute-cypher and compares the decoded tool configs with the expected
// Config values using cmp.Diff. One case omits readOnly and expects the zero
// value; the other sets readOnly: true and expects ReadOnly to be true.
func TestParseFromYamlNeo4j(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
// Each case pairs a YAML document with the tool configs it should decode to.
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: neo4j-execute-cypher
|
||||
source: my-neo4j-instance
|
||||
description: some tool description
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": Config{
|
||||
Name: "example_tool",
|
||||
Kind: "neo4j-execute-cypher",
|
||||
Source: "my-neo4j-instance",
|
||||
Description: "some tool description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "readonly example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: neo4j-execute-cypher
|
||||
source: my-neo4j-instance
|
||||
description: some tool description
|
||||
readOnly: true
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": Config{
|
||||
Name: "example_tool",
|
||||
Kind: "neo4j-execute-cypher",
|
||||
Source: "my-neo4j-instance",
|
||||
ReadOnly: true,
|
||||
Description: "some tool description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// Run every case as its own subtest, named after its description.
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
// NOTE: this assigns to the err declared at the top of the test function.
err = yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -78,6 +79,17 @@ func TestNeo4jToolEndpoints(t *testing.T) {
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"statement": "RETURN 1 as a;",
|
||||
},
|
||||
"my-simple-execute-cypher-tool": map[string]any{
|
||||
"kind": "neo4j-execute-cypher",
|
||||
"source": "my-neo4j-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
},
|
||||
"my-readonly-execute-cypher-tool": map[string]any{
|
||||
"kind": "neo4j-execute-cypher",
|
||||
"source": "my-neo4j-instance",
|
||||
"description": "A readonly cypher execution tool.",
|
||||
"readOnly": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
@@ -111,6 +123,25 @@ func TestNeo4jToolEndpoints(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get my-simple-execute-cypher-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/",
|
||||
want: map[string]any{
|
||||
"my-simple-execute-cypher-tool": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "cypher",
|
||||
"type": "string",
|
||||
"required": true,
|
||||
"description": "The cypher to execute.",
|
||||
"authSources": []any{},
|
||||
},
|
||||
},
|
||||
"authRequired": []any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -141,16 +172,33 @@ func TestNeo4jToolEndpoints(t *testing.T) {
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
wantStatus int
|
||||
wantErrorSubstring string
|
||||
}{
|
||||
{
|
||||
name: "invoke my-simple-cypher-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
want: "[{\"a\":1}]",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invoke my-simple-execute-cypher-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"cypher": "RETURN 1 as a;"}`)),
|
||||
want: "[{\"a\":1}]",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invoke readonly tool with write query",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"cypher": "CREATE (n:TestNode)"}`)),
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErrorSubstring: "this tool is read-only and cannot execute write queries",
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
@@ -160,23 +208,34 @@ func TestNeo4jToolEndpoints(t *testing.T) {
|
||||
t.Fatalf("error when sending a request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode != tc.wantStatus {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
t.Fatalf("response status code: got %d, want %d: %s", resp.StatusCode, tc.wantStatus, string(bodyBytes))
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
if tc.wantStatus == http.StatusOK {
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
} else {
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read error response body: %s", err)
|
||||
}
|
||||
bodyString := string(bodyBytes)
|
||||
if !strings.Contains(bodyString, tc.wantErrorSubstring) {
|
||||
t.Fatalf("response body %q does not contain expected error %q", bodyString, tc.wantErrorSubstring)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user